Cleaned up the directories
This commit is contained in:
parent
f708506d68
commit
a683fcffea
1340 changed files with 554582 additions and 6840 deletions
|
@ -0,0 +1,61 @@
|
|||
# dialects/__init__.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Callable
|
||||
from typing import Optional
|
||||
from typing import Type
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .. import util
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..engine.interfaces import Dialect
|
||||
|
||||
__all__ = ("mssql", "mysql", "oracle", "postgresql", "sqlite")
|
||||
|
||||
|
||||
def _auto_fn(name: str) -> Optional[Callable[[], Type[Dialect]]]:
|
||||
"""default dialect importer.
|
||||
|
||||
plugs into the :class:`.PluginLoader`
|
||||
as a first-hit system.
|
||||
|
||||
"""
|
||||
if "." in name:
|
||||
dialect, driver = name.split(".")
|
||||
else:
|
||||
dialect = name
|
||||
driver = "base"
|
||||
|
||||
try:
|
||||
if dialect == "mariadb":
|
||||
# it's "OK" for us to hardcode here since _auto_fn is already
|
||||
# hardcoded. if mysql / mariadb etc were third party dialects
|
||||
# they would just publish all the entrypoints, which would actually
|
||||
# look much nicer.
|
||||
module = __import__(
|
||||
"sqlalchemy.dialects.mysql.mariadb"
|
||||
).dialects.mysql.mariadb
|
||||
return module.loader(driver) # type: ignore
|
||||
else:
|
||||
module = __import__("sqlalchemy.dialects.%s" % (dialect,)).dialects
|
||||
module = getattr(module, dialect)
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
if hasattr(module, driver):
|
||||
module = getattr(module, driver)
|
||||
return lambda: module.dialect
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
registry = util.PluginLoader("sqlalchemy.dialects", auto_fn=_auto_fn)
|
||||
|
||||
plugins = util.PluginLoader("sqlalchemy.plugins")
|
|
@ -0,0 +1,25 @@
|
|||
# dialects/_typing.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Iterable
|
||||
from typing import Mapping
|
||||
from typing import Optional
|
||||
from typing import Union
|
||||
|
||||
from ..sql._typing import _DDLColumnArgument
|
||||
from ..sql.elements import DQLDMLClauseElement
|
||||
from ..sql.schema import ColumnCollectionConstraint
|
||||
from ..sql.schema import Index
|
||||
|
||||
|
||||
_OnConflictConstraintT = Union[str, ColumnCollectionConstraint, Index, None]
|
||||
_OnConflictIndexElementsT = Optional[Iterable[_DDLColumnArgument]]
|
||||
_OnConflictIndexWhereT = Optional[DQLDMLClauseElement]
|
||||
_OnConflictSetT = Optional[Mapping[Any, Any]]
|
||||
_OnConflictWhereT = Union[DQLDMLClauseElement, str, None]
|
|
@ -0,0 +1,88 @@
|
|||
# dialects/mssql/__init__.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
from . import aioodbc # noqa
|
||||
from . import base # noqa
|
||||
from . import pymssql # noqa
|
||||
from . import pyodbc # noqa
|
||||
from .base import BIGINT
|
||||
from .base import BINARY
|
||||
from .base import BIT
|
||||
from .base import CHAR
|
||||
from .base import DATE
|
||||
from .base import DATETIME
|
||||
from .base import DATETIME2
|
||||
from .base import DATETIMEOFFSET
|
||||
from .base import DECIMAL
|
||||
from .base import DOUBLE_PRECISION
|
||||
from .base import FLOAT
|
||||
from .base import IMAGE
|
||||
from .base import INTEGER
|
||||
from .base import JSON
|
||||
from .base import MONEY
|
||||
from .base import NCHAR
|
||||
from .base import NTEXT
|
||||
from .base import NUMERIC
|
||||
from .base import NVARCHAR
|
||||
from .base import REAL
|
||||
from .base import ROWVERSION
|
||||
from .base import SMALLDATETIME
|
||||
from .base import SMALLINT
|
||||
from .base import SMALLMONEY
|
||||
from .base import SQL_VARIANT
|
||||
from .base import TEXT
|
||||
from .base import TIME
|
||||
from .base import TIMESTAMP
|
||||
from .base import TINYINT
|
||||
from .base import UNIQUEIDENTIFIER
|
||||
from .base import VARBINARY
|
||||
from .base import VARCHAR
|
||||
from .base import XML
|
||||
from ...sql import try_cast
|
||||
|
||||
|
||||
base.dialect = dialect = pyodbc.dialect
|
||||
|
||||
|
||||
__all__ = (
|
||||
"JSON",
|
||||
"INTEGER",
|
||||
"BIGINT",
|
||||
"SMALLINT",
|
||||
"TINYINT",
|
||||
"VARCHAR",
|
||||
"NVARCHAR",
|
||||
"CHAR",
|
||||
"NCHAR",
|
||||
"TEXT",
|
||||
"NTEXT",
|
||||
"DECIMAL",
|
||||
"NUMERIC",
|
||||
"FLOAT",
|
||||
"DATETIME",
|
||||
"DATETIME2",
|
||||
"DATETIMEOFFSET",
|
||||
"DATE",
|
||||
"DOUBLE_PRECISION",
|
||||
"TIME",
|
||||
"SMALLDATETIME",
|
||||
"BINARY",
|
||||
"VARBINARY",
|
||||
"BIT",
|
||||
"REAL",
|
||||
"IMAGE",
|
||||
"TIMESTAMP",
|
||||
"ROWVERSION",
|
||||
"MONEY",
|
||||
"SMALLMONEY",
|
||||
"UNIQUEIDENTIFIER",
|
||||
"SQL_VARIANT",
|
||||
"XML",
|
||||
"dialect",
|
||||
"try_cast",
|
||||
)
|
|
@ -0,0 +1,64 @@
|
|||
# dialects/mssql/aioodbc.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
r"""
|
||||
.. dialect:: mssql+aioodbc
|
||||
:name: aioodbc
|
||||
:dbapi: aioodbc
|
||||
:connectstring: mssql+aioodbc://<username>:<password>@<dsnname>
|
||||
:url: https://pypi.org/project/aioodbc/
|
||||
|
||||
|
||||
Support for the SQL Server database in asyncio style, using the aioodbc
|
||||
driver which itself is a thread-wrapper around pyodbc.
|
||||
|
||||
.. versionadded:: 2.0.23 Added the mssql+aioodbc dialect which builds
|
||||
on top of the pyodbc and general aio* dialect architecture.
|
||||
|
||||
Using a special asyncio mediation layer, the aioodbc dialect is usable
|
||||
as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>`
|
||||
extension package.
|
||||
|
||||
Most behaviors and caveats for this driver are the same as that of the
|
||||
pyodbc dialect used on SQL Server; see :ref:`mssql_pyodbc` for general
|
||||
background.
|
||||
|
||||
This dialect should normally be used only with the
|
||||
:func:`_asyncio.create_async_engine` engine creation function; connection
|
||||
styles are otherwise equivalent to those documented in the pyodbc section::
|
||||
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
engine = create_async_engine(
|
||||
"mssql+aioodbc://scott:tiger@mssql2017:1433/test?"
|
||||
"driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes"
|
||||
)
|
||||
|
||||
|
||||
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from .pyodbc import MSDialect_pyodbc
|
||||
from .pyodbc import MSExecutionContext_pyodbc
|
||||
from ...connectors.aioodbc import aiodbcConnector
|
||||
|
||||
|
||||
class MSExecutionContext_aioodbc(MSExecutionContext_pyodbc):
|
||||
def create_server_side_cursor(self):
|
||||
return self._dbapi_connection.cursor(server_side=True)
|
||||
|
||||
|
||||
class MSDialectAsync_aioodbc(aiodbcConnector, MSDialect_pyodbc):
|
||||
driver = "aioodbc"
|
||||
|
||||
supports_statement_cache = True
|
||||
|
||||
execution_ctx_cls = MSExecutionContext_aioodbc
|
||||
|
||||
|
||||
dialect = MSDialectAsync_aioodbc
|
File diff suppressed because it is too large
Load diff
|
@ -0,0 +1,254 @@
|
|||
# dialects/mssql/information_schema.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
from ... import cast
|
||||
from ... import Column
|
||||
from ... import MetaData
|
||||
from ... import Table
|
||||
from ...ext.compiler import compiles
|
||||
from ...sql import expression
|
||||
from ...types import Boolean
|
||||
from ...types import Integer
|
||||
from ...types import Numeric
|
||||
from ...types import NVARCHAR
|
||||
from ...types import String
|
||||
from ...types import TypeDecorator
|
||||
from ...types import Unicode
|
||||
|
||||
|
||||
ischema = MetaData()
|
||||
|
||||
|
||||
class CoerceUnicode(TypeDecorator):
|
||||
impl = Unicode
|
||||
cache_ok = True
|
||||
|
||||
def bind_expression(self, bindvalue):
|
||||
return _cast_on_2005(bindvalue)
|
||||
|
||||
|
||||
class _cast_on_2005(expression.ColumnElement):
|
||||
def __init__(self, bindvalue):
|
||||
self.bindvalue = bindvalue
|
||||
|
||||
|
||||
@compiles(_cast_on_2005)
|
||||
def _compile(element, compiler, **kw):
|
||||
from . import base
|
||||
|
||||
if (
|
||||
compiler.dialect.server_version_info is None
|
||||
or compiler.dialect.server_version_info < base.MS_2005_VERSION
|
||||
):
|
||||
return compiler.process(element.bindvalue, **kw)
|
||||
else:
|
||||
return compiler.process(cast(element.bindvalue, Unicode), **kw)
|
||||
|
||||
|
||||
schemata = Table(
|
||||
"SCHEMATA",
|
||||
ischema,
|
||||
Column("CATALOG_NAME", CoerceUnicode, key="catalog_name"),
|
||||
Column("SCHEMA_NAME", CoerceUnicode, key="schema_name"),
|
||||
Column("SCHEMA_OWNER", CoerceUnicode, key="schema_owner"),
|
||||
schema="INFORMATION_SCHEMA",
|
||||
)
|
||||
|
||||
tables = Table(
|
||||
"TABLES",
|
||||
ischema,
|
||||
Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"),
|
||||
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
|
||||
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
|
||||
Column("TABLE_TYPE", CoerceUnicode, key="table_type"),
|
||||
schema="INFORMATION_SCHEMA",
|
||||
)
|
||||
|
||||
columns = Table(
|
||||
"COLUMNS",
|
||||
ischema,
|
||||
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
|
||||
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
|
||||
Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
|
||||
Column("IS_NULLABLE", Integer, key="is_nullable"),
|
||||
Column("DATA_TYPE", String, key="data_type"),
|
||||
Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
|
||||
Column(
|
||||
"CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length"
|
||||
),
|
||||
Column("NUMERIC_PRECISION", Integer, key="numeric_precision"),
|
||||
Column("NUMERIC_SCALE", Integer, key="numeric_scale"),
|
||||
Column("COLUMN_DEFAULT", Integer, key="column_default"),
|
||||
Column("COLLATION_NAME", String, key="collation_name"),
|
||||
schema="INFORMATION_SCHEMA",
|
||||
)
|
||||
|
||||
mssql_temp_table_columns = Table(
|
||||
"COLUMNS",
|
||||
ischema,
|
||||
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
|
||||
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
|
||||
Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
|
||||
Column("IS_NULLABLE", Integer, key="is_nullable"),
|
||||
Column("DATA_TYPE", String, key="data_type"),
|
||||
Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
|
||||
Column(
|
||||
"CHARACTER_MAXIMUM_LENGTH", Integer, key="character_maximum_length"
|
||||
),
|
||||
Column("NUMERIC_PRECISION", Integer, key="numeric_precision"),
|
||||
Column("NUMERIC_SCALE", Integer, key="numeric_scale"),
|
||||
Column("COLUMN_DEFAULT", Integer, key="column_default"),
|
||||
Column("COLLATION_NAME", String, key="collation_name"),
|
||||
schema="tempdb.INFORMATION_SCHEMA",
|
||||
)
|
||||
|
||||
constraints = Table(
|
||||
"TABLE_CONSTRAINTS",
|
||||
ischema,
|
||||
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
|
||||
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
|
||||
Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
|
||||
Column("CONSTRAINT_TYPE", CoerceUnicode, key="constraint_type"),
|
||||
schema="INFORMATION_SCHEMA",
|
||||
)
|
||||
|
||||
column_constraints = Table(
|
||||
"CONSTRAINT_COLUMN_USAGE",
|
||||
ischema,
|
||||
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
|
||||
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
|
||||
Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
|
||||
Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
|
||||
schema="INFORMATION_SCHEMA",
|
||||
)
|
||||
|
||||
key_constraints = Table(
|
||||
"KEY_COLUMN_USAGE",
|
||||
ischema,
|
||||
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
|
||||
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
|
||||
Column("COLUMN_NAME", CoerceUnicode, key="column_name"),
|
||||
Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
|
||||
Column("CONSTRAINT_SCHEMA", CoerceUnicode, key="constraint_schema"),
|
||||
Column("ORDINAL_POSITION", Integer, key="ordinal_position"),
|
||||
schema="INFORMATION_SCHEMA",
|
||||
)
|
||||
|
||||
ref_constraints = Table(
|
||||
"REFERENTIAL_CONSTRAINTS",
|
||||
ischema,
|
||||
Column("CONSTRAINT_CATALOG", CoerceUnicode, key="constraint_catalog"),
|
||||
Column("CONSTRAINT_SCHEMA", CoerceUnicode, key="constraint_schema"),
|
||||
Column("CONSTRAINT_NAME", CoerceUnicode, key="constraint_name"),
|
||||
# TODO: is CATLOG misspelled ?
|
||||
Column(
|
||||
"UNIQUE_CONSTRAINT_CATLOG",
|
||||
CoerceUnicode,
|
||||
key="unique_constraint_catalog",
|
||||
),
|
||||
Column(
|
||||
"UNIQUE_CONSTRAINT_SCHEMA",
|
||||
CoerceUnicode,
|
||||
key="unique_constraint_schema",
|
||||
),
|
||||
Column(
|
||||
"UNIQUE_CONSTRAINT_NAME", CoerceUnicode, key="unique_constraint_name"
|
||||
),
|
||||
Column("MATCH_OPTION", String, key="match_option"),
|
||||
Column("UPDATE_RULE", String, key="update_rule"),
|
||||
Column("DELETE_RULE", String, key="delete_rule"),
|
||||
schema="INFORMATION_SCHEMA",
|
||||
)
|
||||
|
||||
views = Table(
|
||||
"VIEWS",
|
||||
ischema,
|
||||
Column("TABLE_CATALOG", CoerceUnicode, key="table_catalog"),
|
||||
Column("TABLE_SCHEMA", CoerceUnicode, key="table_schema"),
|
||||
Column("TABLE_NAME", CoerceUnicode, key="table_name"),
|
||||
Column("VIEW_DEFINITION", CoerceUnicode, key="view_definition"),
|
||||
Column("CHECK_OPTION", String, key="check_option"),
|
||||
Column("IS_UPDATABLE", String, key="is_updatable"),
|
||||
schema="INFORMATION_SCHEMA",
|
||||
)
|
||||
|
||||
computed_columns = Table(
|
||||
"computed_columns",
|
||||
ischema,
|
||||
Column("object_id", Integer),
|
||||
Column("name", CoerceUnicode),
|
||||
Column("is_computed", Boolean),
|
||||
Column("is_persisted", Boolean),
|
||||
Column("definition", CoerceUnicode),
|
||||
schema="sys",
|
||||
)
|
||||
|
||||
sequences = Table(
|
||||
"SEQUENCES",
|
||||
ischema,
|
||||
Column("SEQUENCE_CATALOG", CoerceUnicode, key="sequence_catalog"),
|
||||
Column("SEQUENCE_SCHEMA", CoerceUnicode, key="sequence_schema"),
|
||||
Column("SEQUENCE_NAME", CoerceUnicode, key="sequence_name"),
|
||||
schema="INFORMATION_SCHEMA",
|
||||
)
|
||||
|
||||
|
||||
class NumericSqlVariant(TypeDecorator):
|
||||
r"""This type casts sql_variant columns in the identity_columns view
|
||||
to numeric. This is required because:
|
||||
|
||||
* pyodbc does not support sql_variant
|
||||
* pymssql under python 2 return the byte representation of the number,
|
||||
int 1 is returned as "\x01\x00\x00\x00". On python 3 it returns the
|
||||
correct value as string.
|
||||
"""
|
||||
|
||||
impl = Unicode
|
||||
cache_ok = True
|
||||
|
||||
def column_expression(self, colexpr):
|
||||
return cast(colexpr, Numeric(38, 0))
|
||||
|
||||
|
||||
identity_columns = Table(
|
||||
"identity_columns",
|
||||
ischema,
|
||||
Column("object_id", Integer),
|
||||
Column("name", CoerceUnicode),
|
||||
Column("is_identity", Boolean),
|
||||
Column("seed_value", NumericSqlVariant),
|
||||
Column("increment_value", NumericSqlVariant),
|
||||
Column("last_value", NumericSqlVariant),
|
||||
Column("is_not_for_replication", Boolean),
|
||||
schema="sys",
|
||||
)
|
||||
|
||||
|
||||
class NVarcharSqlVariant(TypeDecorator):
|
||||
"""This type casts sql_variant columns in the extended_properties view
|
||||
to nvarchar. This is required because pyodbc does not support sql_variant
|
||||
"""
|
||||
|
||||
impl = Unicode
|
||||
cache_ok = True
|
||||
|
||||
def column_expression(self, colexpr):
|
||||
return cast(colexpr, NVARCHAR)
|
||||
|
||||
|
||||
extended_properties = Table(
|
||||
"extended_properties",
|
||||
ischema,
|
||||
Column("class", Integer), # TINYINT
|
||||
Column("class_desc", CoerceUnicode),
|
||||
Column("major_id", Integer),
|
||||
Column("minor_id", Integer),
|
||||
Column("name", CoerceUnicode),
|
||||
Column("value", NVarcharSqlVariant),
|
||||
schema="sys",
|
||||
)
|
|
@ -0,0 +1,133 @@
|
|||
# dialects/mssql/json.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
from ... import types as sqltypes
|
||||
|
||||
# technically, all the dialect-specific datatypes that don't have any special
|
||||
# behaviors would be private with names like _MSJson. However, we haven't been
|
||||
# doing this for mysql.JSON or sqlite.JSON which both have JSON / JSONIndexType
|
||||
# / JSONPathType in their json.py files, so keep consistent with that
|
||||
# sub-convention for now. A future change can update them all to be
|
||||
# package-private at once.
|
||||
|
||||
|
||||
class JSON(sqltypes.JSON):
|
||||
"""MSSQL JSON type.
|
||||
|
||||
MSSQL supports JSON-formatted data as of SQL Server 2016.
|
||||
|
||||
The :class:`_mssql.JSON` datatype at the DDL level will represent the
|
||||
datatype as ``NVARCHAR(max)``, but provides for JSON-level comparison
|
||||
functions as well as Python coercion behavior.
|
||||
|
||||
:class:`_mssql.JSON` is used automatically whenever the base
|
||||
:class:`_types.JSON` datatype is used against a SQL Server backend.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:class:`_types.JSON` - main documentation for the generic
|
||||
cross-platform JSON datatype.
|
||||
|
||||
The :class:`_mssql.JSON` type supports persistence of JSON values
|
||||
as well as the core index operations provided by :class:`_types.JSON`
|
||||
datatype, by adapting the operations to render the ``JSON_VALUE``
|
||||
or ``JSON_QUERY`` functions at the database level.
|
||||
|
||||
The SQL Server :class:`_mssql.JSON` type necessarily makes use of the
|
||||
``JSON_QUERY`` and ``JSON_VALUE`` functions when querying for elements
|
||||
of a JSON object. These two functions have a major restriction in that
|
||||
they are **mutually exclusive** based on the type of object to be returned.
|
||||
The ``JSON_QUERY`` function **only** returns a JSON dictionary or list,
|
||||
but not an individual string, numeric, or boolean element; the
|
||||
``JSON_VALUE`` function **only** returns an individual string, numeric,
|
||||
or boolean element. **both functions either return NULL or raise
|
||||
an error if they are not used against the correct expected value**.
|
||||
|
||||
To handle this awkward requirement, indexed access rules are as follows:
|
||||
|
||||
1. When extracting a sub element from a JSON that is itself a JSON
|
||||
dictionary or list, the :meth:`_types.JSON.Comparator.as_json` accessor
|
||||
should be used::
|
||||
|
||||
stmt = select(
|
||||
data_table.c.data["some key"].as_json()
|
||||
).where(
|
||||
data_table.c.data["some key"].as_json() == {"sub": "structure"}
|
||||
)
|
||||
|
||||
2. When extracting a sub element from a JSON that is a plain boolean,
|
||||
string, integer, or float, use the appropriate method among
|
||||
:meth:`_types.JSON.Comparator.as_boolean`,
|
||||
:meth:`_types.JSON.Comparator.as_string`,
|
||||
:meth:`_types.JSON.Comparator.as_integer`,
|
||||
:meth:`_types.JSON.Comparator.as_float`::
|
||||
|
||||
stmt = select(
|
||||
data_table.c.data["some key"].as_string()
|
||||
).where(
|
||||
data_table.c.data["some key"].as_string() == "some string"
|
||||
)
|
||||
|
||||
.. versionadded:: 1.4
|
||||
|
||||
|
||||
"""
|
||||
|
||||
# note there was a result processor here that was looking for "number",
|
||||
# but none of the tests seem to exercise it.
|
||||
|
||||
|
||||
# Note: these objects currently match exactly those of MySQL, however since
|
||||
# these are not generalizable to all JSON implementations, remain separately
|
||||
# implemented for each dialect.
|
||||
class _FormatTypeMixin:
|
||||
def _format_value(self, value):
|
||||
raise NotImplementedError()
|
||||
|
||||
def bind_processor(self, dialect):
|
||||
super_proc = self.string_bind_processor(dialect)
|
||||
|
||||
def process(value):
|
||||
value = self._format_value(value)
|
||||
if super_proc:
|
||||
value = super_proc(value)
|
||||
return value
|
||||
|
||||
return process
|
||||
|
||||
def literal_processor(self, dialect):
|
||||
super_proc = self.string_literal_processor(dialect)
|
||||
|
||||
def process(value):
|
||||
value = self._format_value(value)
|
||||
if super_proc:
|
||||
value = super_proc(value)
|
||||
return value
|
||||
|
||||
return process
|
||||
|
||||
|
||||
class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType):
|
||||
def _format_value(self, value):
|
||||
if isinstance(value, int):
|
||||
value = "$[%s]" % value
|
||||
else:
|
||||
value = '$."%s"' % value
|
||||
return value
|
||||
|
||||
|
||||
class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType):
|
||||
def _format_value(self, value):
|
||||
return "$%s" % (
|
||||
"".join(
|
||||
[
|
||||
"[%s]" % elem if isinstance(elem, int) else '."%s"' % elem
|
||||
for elem in value
|
||||
]
|
||||
)
|
||||
)
|
|
@ -0,0 +1,155 @@
|
|||
# dialects/mssql/provision.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
from sqlalchemy import inspect
|
||||
from sqlalchemy import Integer
|
||||
from ... import create_engine
|
||||
from ... import exc
|
||||
from ...schema import Column
|
||||
from ...schema import DropConstraint
|
||||
from ...schema import ForeignKeyConstraint
|
||||
from ...schema import MetaData
|
||||
from ...schema import Table
|
||||
from ...testing.provision import create_db
|
||||
from ...testing.provision import drop_all_schema_objects_pre_tables
|
||||
from ...testing.provision import drop_db
|
||||
from ...testing.provision import generate_driver_url
|
||||
from ...testing.provision import get_temp_table_name
|
||||
from ...testing.provision import log
|
||||
from ...testing.provision import normalize_sequence
|
||||
from ...testing.provision import run_reap_dbs
|
||||
from ...testing.provision import temp_table_keyword_args
|
||||
|
||||
|
||||
@generate_driver_url.for_db("mssql")
|
||||
def generate_driver_url(url, driver, query_str):
|
||||
backend = url.get_backend_name()
|
||||
|
||||
new_url = url.set(drivername="%s+%s" % (backend, driver))
|
||||
|
||||
if driver not in ("pyodbc", "aioodbc"):
|
||||
new_url = new_url.set(query="")
|
||||
|
||||
if driver == "aioodbc":
|
||||
new_url = new_url.update_query_dict({"MARS_Connection": "Yes"})
|
||||
|
||||
if query_str:
|
||||
new_url = new_url.update_query_string(query_str)
|
||||
|
||||
try:
|
||||
new_url.get_dialect()
|
||||
except exc.NoSuchModuleError:
|
||||
return None
|
||||
else:
|
||||
return new_url
|
||||
|
||||
|
||||
@create_db.for_db("mssql")
|
||||
def _mssql_create_db(cfg, eng, ident):
|
||||
with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
|
||||
conn.exec_driver_sql("create database %s" % ident)
|
||||
conn.exec_driver_sql(
|
||||
"ALTER DATABASE %s SET ALLOW_SNAPSHOT_ISOLATION ON" % ident
|
||||
)
|
||||
conn.exec_driver_sql(
|
||||
"ALTER DATABASE %s SET READ_COMMITTED_SNAPSHOT ON" % ident
|
||||
)
|
||||
conn.exec_driver_sql("use %s" % ident)
|
||||
conn.exec_driver_sql("create schema test_schema")
|
||||
conn.exec_driver_sql("create schema test_schema_2")
|
||||
|
||||
|
||||
@drop_db.for_db("mssql")
|
||||
def _mssql_drop_db(cfg, eng, ident):
|
||||
with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
|
||||
_mssql_drop_ignore(conn, ident)
|
||||
|
||||
|
||||
def _mssql_drop_ignore(conn, ident):
|
||||
try:
|
||||
# typically when this happens, we can't KILL the session anyway,
|
||||
# so let the cleanup process drop the DBs
|
||||
# for row in conn.exec_driver_sql(
|
||||
# "select session_id from sys.dm_exec_sessions "
|
||||
# "where database_id=db_id('%s')" % ident):
|
||||
# log.info("killing SQL server session %s", row['session_id'])
|
||||
# conn.exec_driver_sql("kill %s" % row['session_id'])
|
||||
conn.exec_driver_sql("drop database %s" % ident)
|
||||
log.info("Reaped db: %s", ident)
|
||||
return True
|
||||
except exc.DatabaseError as err:
|
||||
log.warning("couldn't drop db: %s", err)
|
||||
return False
|
||||
|
||||
|
||||
@run_reap_dbs.for_db("mssql")
|
||||
def _reap_mssql_dbs(url, idents):
|
||||
log.info("db reaper connecting to %r", url)
|
||||
eng = create_engine(url)
|
||||
with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
|
||||
log.info("identifiers in file: %s", ", ".join(idents))
|
||||
|
||||
to_reap = conn.exec_driver_sql(
|
||||
"select d.name from sys.databases as d where name "
|
||||
"like 'TEST_%' and not exists (select session_id "
|
||||
"from sys.dm_exec_sessions "
|
||||
"where database_id=d.database_id)"
|
||||
)
|
||||
all_names = {dbname.lower() for (dbname,) in to_reap}
|
||||
to_drop = set()
|
||||
for name in all_names:
|
||||
if name in idents:
|
||||
to_drop.add(name)
|
||||
|
||||
dropped = total = 0
|
||||
for total, dbname in enumerate(to_drop, 1):
|
||||
if _mssql_drop_ignore(conn, dbname):
|
||||
dropped += 1
|
||||
log.info(
|
||||
"Dropped %d out of %d stale databases detected", dropped, total
|
||||
)
|
||||
|
||||
|
||||
@temp_table_keyword_args.for_db("mssql")
|
||||
def _mssql_temp_table_keyword_args(cfg, eng):
|
||||
return {}
|
||||
|
||||
|
||||
@get_temp_table_name.for_db("mssql")
|
||||
def _mssql_get_temp_table_name(cfg, eng, base_name):
|
||||
return "##" + base_name
|
||||
|
||||
|
||||
@drop_all_schema_objects_pre_tables.for_db("mssql")
|
||||
def drop_all_schema_objects_pre_tables(cfg, eng):
|
||||
with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
|
||||
inspector = inspect(conn)
|
||||
for schema in (None, "dbo", cfg.test_schema, cfg.test_schema_2):
|
||||
for tname in inspector.get_table_names(schema=schema):
|
||||
tb = Table(
|
||||
tname,
|
||||
MetaData(),
|
||||
Column("x", Integer),
|
||||
Column("y", Integer),
|
||||
schema=schema,
|
||||
)
|
||||
for fk in inspect(conn).get_foreign_keys(tname, schema=schema):
|
||||
conn.execute(
|
||||
DropConstraint(
|
||||
ForeignKeyConstraint(
|
||||
[tb.c.x], [tb.c.y], name=fk["name"]
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@normalize_sequence.for_db("mssql")
|
||||
def normalize_sequence(cfg, sequence):
|
||||
if sequence.start is None:
|
||||
sequence.start = 1
|
||||
return sequence
|
|
@ -0,0 +1,125 @@
|
|||
# dialects/mssql/pymssql.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
"""
|
||||
.. dialect:: mssql+pymssql
|
||||
:name: pymssql
|
||||
:dbapi: pymssql
|
||||
:connectstring: mssql+pymssql://<username>:<password>@<freetds_name>/?charset=utf8
|
||||
|
||||
pymssql is a Python module that provides a Python DBAPI interface around
|
||||
`FreeTDS <https://www.freetds.org/>`_.
|
||||
|
||||
.. versionchanged:: 2.0.5
|
||||
|
||||
pymssql was restored to SQLAlchemy's continuous integration testing
|
||||
|
||||
|
||||
""" # noqa
|
||||
import re
|
||||
|
||||
from .base import MSDialect
|
||||
from .base import MSIdentifierPreparer
|
||||
from ... import types as sqltypes
|
||||
from ... import util
|
||||
from ...engine import processors
|
||||
|
||||
|
||||
class _MSNumeric_pymssql(sqltypes.Numeric):
|
||||
def result_processor(self, dialect, type_):
|
||||
if not self.asdecimal:
|
||||
return processors.to_float
|
||||
else:
|
||||
return sqltypes.Numeric.result_processor(self, dialect, type_)
|
||||
|
||||
|
||||
class MSIdentifierPreparer_pymssql(MSIdentifierPreparer):
|
||||
def __init__(self, dialect):
|
||||
super().__init__(dialect)
|
||||
# pymssql has the very unusual behavior that it uses pyformat
|
||||
# yet does not require that percent signs be doubled
|
||||
self._double_percents = False
|
||||
|
||||
|
||||
class MSDialect_pymssql(MSDialect):
|
||||
supports_statement_cache = True
|
||||
supports_native_decimal = True
|
||||
supports_native_uuid = True
|
||||
driver = "pymssql"
|
||||
|
||||
preparer = MSIdentifierPreparer_pymssql
|
||||
|
||||
colspecs = util.update_copy(
|
||||
MSDialect.colspecs,
|
||||
{sqltypes.Numeric: _MSNumeric_pymssql, sqltypes.Float: sqltypes.Float},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def import_dbapi(cls):
|
||||
module = __import__("pymssql")
|
||||
# pymmsql < 2.1.1 doesn't have a Binary method. we use string
|
||||
client_ver = tuple(int(x) for x in module.__version__.split("."))
|
||||
if client_ver < (2, 1, 1):
|
||||
# TODO: monkeypatching here is less than ideal
|
||||
module.Binary = lambda x: x if hasattr(x, "decode") else str(x)
|
||||
|
||||
if client_ver < (1,):
|
||||
util.warn(
|
||||
"The pymssql dialect expects at least "
|
||||
"the 1.0 series of the pymssql DBAPI."
|
||||
)
|
||||
return module
|
||||
|
||||
def _get_server_version_info(self, connection):
|
||||
vers = connection.exec_driver_sql("select @@version").scalar()
|
||||
m = re.match(r"Microsoft .*? - (\d+)\.(\d+)\.(\d+)\.(\d+)", vers)
|
||||
if m:
|
||||
return tuple(int(x) for x in m.group(1, 2, 3, 4))
|
||||
else:
|
||||
return None
|
||||
|
||||
def create_connect_args(self, url):
|
||||
opts = url.translate_connect_args(username="user")
|
||||
opts.update(url.query)
|
||||
port = opts.pop("port", None)
|
||||
if port and "host" in opts:
|
||||
opts["host"] = "%s:%s" % (opts["host"], port)
|
||||
return ([], opts)
|
||||
|
||||
def is_disconnect(self, e, connection, cursor):
|
||||
for msg in (
|
||||
"Adaptive Server connection timed out",
|
||||
"Net-Lib error during Connection reset by peer",
|
||||
"message 20003", # connection timeout
|
||||
"Error 10054",
|
||||
"Not connected to any MS SQL server",
|
||||
"Connection is closed",
|
||||
"message 20006", # Write to the server failed
|
||||
"message 20017", # Unexpected EOF from the server
|
||||
"message 20047", # DBPROCESS is dead or not enabled
|
||||
):
|
||||
if msg in str(e):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def get_isolation_level_values(self, dbapi_connection):
|
||||
return super().get_isolation_level_values(dbapi_connection) + [
|
||||
"AUTOCOMMIT"
|
||||
]
|
||||
|
||||
def set_isolation_level(self, dbapi_connection, level):
|
||||
if level == "AUTOCOMMIT":
|
||||
dbapi_connection.autocommit(True)
|
||||
else:
|
||||
dbapi_connection.autocommit(False)
|
||||
super().set_isolation_level(dbapi_connection, level)
|
||||
|
||||
|
||||
dialect = MSDialect_pymssql
|
|
@ -0,0 +1,745 @@
|
|||
# dialects/mssql/pyodbc.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
r"""
|
||||
.. dialect:: mssql+pyodbc
|
||||
:name: PyODBC
|
||||
:dbapi: pyodbc
|
||||
:connectstring: mssql+pyodbc://<username>:<password>@<dsnname>
|
||||
:url: https://pypi.org/project/pyodbc/
|
||||
|
||||
Connecting to PyODBC
|
||||
--------------------
|
||||
|
||||
The URL here is to be translated to PyODBC connection strings, as
|
||||
detailed in `ConnectionStrings <https://code.google.com/p/pyodbc/wiki/ConnectionStrings>`_.
|
||||
|
||||
DSN Connections
|
||||
^^^^^^^^^^^^^^^
|
||||
|
||||
A DSN connection in ODBC means that a pre-existing ODBC datasource is
|
||||
configured on the client machine. The application then specifies the name
|
||||
of this datasource, which encompasses details such as the specific ODBC driver
|
||||
in use as well as the network address of the database. Assuming a datasource
|
||||
is configured on the client, a basic DSN-based connection looks like::
|
||||
|
||||
engine = create_engine("mssql+pyodbc://scott:tiger@some_dsn")
|
||||
|
||||
Which above, will pass the following connection string to PyODBC::
|
||||
|
||||
DSN=some_dsn;UID=scott;PWD=tiger
|
||||
|
||||
If the username and password are omitted, the DSN form will also add
|
||||
the ``Trusted_Connection=yes`` directive to the ODBC string.
|
||||
|
||||
Hostname Connections
|
||||
^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Hostname-based connections are also supported by pyodbc. These are often
|
||||
easier to use than a DSN and have the additional advantage that the specific
|
||||
database name to connect towards may be specified locally in the URL, rather
|
||||
than it being fixed as part of a datasource configuration.
|
||||
|
||||
When using a hostname connection, the driver name must also be specified in the
|
||||
query parameters of the URL. As these names usually have spaces in them, the
|
||||
name must be URL encoded which means using plus signs for spaces::
|
||||
|
||||
engine = create_engine("mssql+pyodbc://scott:tiger@myhost:port/databasename?driver=ODBC+Driver+17+for+SQL+Server")
|
||||
|
||||
The ``driver`` keyword is significant to the pyodbc dialect and must be
|
||||
specified in lowercase.
|
||||
|
||||
Any other names passed in the query string are passed through in the pyodbc
|
||||
connect string, such as ``authentication``, ``TrustServerCertificate``, etc.
|
||||
Multiple keyword arguments must be separated by an ampersand (``&``); these
|
||||
will be translated to semicolons when the pyodbc connect string is generated
|
||||
internally::
|
||||
|
||||
e = create_engine(
|
||||
"mssql+pyodbc://scott:tiger@mssql2017:1433/test?"
|
||||
"driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes"
|
||||
"&authentication=ActiveDirectoryIntegrated"
|
||||
)
|
||||
|
||||
The equivalent URL can be constructed using :class:`_sa.engine.URL`::
|
||||
|
||||
from sqlalchemy.engine import URL
|
||||
connection_url = URL.create(
|
||||
"mssql+pyodbc",
|
||||
username="scott",
|
||||
password="tiger",
|
||||
host="mssql2017",
|
||||
port=1433,
|
||||
database="test",
|
||||
query={
|
||||
"driver": "ODBC Driver 18 for SQL Server",
|
||||
"TrustServerCertificate": "yes",
|
||||
"authentication": "ActiveDirectoryIntegrated",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
Pass through exact Pyodbc string
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
A PyODBC connection string can also be sent in pyodbc's format directly, as
|
||||
specified in `the PyODBC documentation
|
||||
<https://github.com/mkleehammer/pyodbc/wiki/Connecting-to-databases>`_,
|
||||
using the parameter ``odbc_connect``. A :class:`_sa.engine.URL` object
|
||||
can help make this easier::
|
||||
|
||||
from sqlalchemy.engine import URL
|
||||
connection_string = "DRIVER={SQL Server Native Client 10.0};SERVER=dagger;DATABASE=test;UID=user;PWD=password"
|
||||
connection_url = URL.create("mssql+pyodbc", query={"odbc_connect": connection_string})
|
||||
|
||||
engine = create_engine(connection_url)
|
||||
|
||||
.. _mssql_pyodbc_access_tokens:
|
||||
|
||||
Connecting to databases with access tokens
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Some database servers are set up to only accept access tokens for login. For
|
||||
example, SQL Server allows the use of Azure Active Directory tokens to connect
|
||||
to databases. This requires creating a credential object using the
|
||||
``azure-identity`` library. More information about the authentication step can be
|
||||
found in `Microsoft's documentation
|
||||
<https://docs.microsoft.com/en-us/azure/developer/python/azure-sdk-authenticate?tabs=bash>`_.
|
||||
|
||||
After getting an engine, the credentials need to be sent to ``pyodbc.connect``
|
||||
each time a connection is requested. One way to do this is to set up an event
|
||||
listener on the engine that adds the credential token to the dialect's connect
|
||||
call. This is discussed more generally in :ref:`engines_dynamic_tokens`. For
|
||||
SQL Server in particular, this is passed as an ODBC connection attribute with
|
||||
a data structure `described by Microsoft
|
||||
<https://docs.microsoft.com/en-us/sql/connect/odbc/using-azure-active-directory#authenticating-with-an-access-token>`_.
|
||||
|
||||
The following code snippet will create an engine that connects to an Azure SQL
|
||||
database using Azure credentials::
|
||||
|
||||
import struct
|
||||
from sqlalchemy import create_engine, event
|
||||
from sqlalchemy.engine.url import URL
|
||||
from azure import identity
|
||||
|
||||
SQL_COPT_SS_ACCESS_TOKEN = 1256 # Connection option for access tokens, as defined in msodbcsql.h
|
||||
TOKEN_URL = "https://database.windows.net/" # The token URL for any Azure SQL database
|
||||
|
||||
connection_string = "mssql+pyodbc://@my-server.database.windows.net/myDb?driver=ODBC+Driver+17+for+SQL+Server"
|
||||
|
||||
engine = create_engine(connection_string)
|
||||
|
||||
azure_credentials = identity.DefaultAzureCredential()
|
||||
|
||||
@event.listens_for(engine, "do_connect")
|
||||
def provide_token(dialect, conn_rec, cargs, cparams):
|
||||
# remove the "Trusted_Connection" parameter that SQLAlchemy adds
|
||||
cargs[0] = cargs[0].replace(";Trusted_Connection=Yes", "")
|
||||
|
||||
# create token credential
|
||||
raw_token = azure_credentials.get_token(TOKEN_URL).token.encode("utf-16-le")
|
||||
token_struct = struct.pack(f"<I{len(raw_token)}s", len(raw_token), raw_token)
|
||||
|
||||
# apply it to keyword arguments
|
||||
cparams["attrs_before"] = {SQL_COPT_SS_ACCESS_TOKEN: token_struct}
|
||||
|
||||
.. tip::
|
||||
|
||||
The ``Trusted_Connection`` token is currently added by the SQLAlchemy
|
||||
pyodbc dialect when no username or password is present. This needs
|
||||
to be removed per Microsoft's
|
||||
`documentation for Azure access tokens
|
||||
<https://docs.microsoft.com/en-us/sql/connect/odbc/using-azure-active-directory#authenticating-with-an-access-token>`_,
|
||||
stating that a connection string when using an access token must not contain
|
||||
``UID``, ``PWD``, ``Authentication`` or ``Trusted_Connection`` parameters.
|
||||
|
||||
.. _azure_synapse_ignore_no_transaction_on_rollback:
|
||||
|
||||
Avoiding transaction-related exceptions on Azure Synapse Analytics
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Azure Synapse Analytics has a significant difference in its transaction
|
||||
handling compared to plain SQL Server; in some cases an error within a Synapse
|
||||
transaction can cause it to be arbitrarily terminated on the server side, which
|
||||
then causes the DBAPI ``.rollback()`` method (as well as ``.commit()``) to
|
||||
fail. The issue prevents the usual DBAPI contract of allowing ``.rollback()``
|
||||
to pass silently if no transaction is present as the driver does not expect
|
||||
this condition. The symptom of this failure is an exception with a message
|
||||
resembling 'No corresponding transaction found. (111214)' when attempting to
|
||||
emit a ``.rollback()`` after an operation had a failure of some kind.
|
||||
|
||||
This specific case can be handled by passing ``ignore_no_transaction_on_rollback=True`` to
|
||||
the SQL Server dialect via the :func:`_sa.create_engine` function as follows::
|
||||
|
||||
engine = create_engine(connection_url, ignore_no_transaction_on_rollback=True)
|
||||
|
||||
Using the above parameter, the dialect will catch ``ProgrammingError``
|
||||
exceptions raised during ``connection.rollback()`` and emit a warning
|
||||
if the error message contains code ``111214``, however will not raise
|
||||
an exception.
|
||||
|
||||
.. versionadded:: 1.4.40 Added the
|
||||
``ignore_no_transaction_on_rollback=True`` parameter.
|
||||
|
||||
Enable autocommit for Azure SQL Data Warehouse (DW) connections
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Azure SQL Data Warehouse does not support transactions,
|
||||
and that can cause problems with SQLAlchemy's "autobegin" (and implicit
|
||||
commit/rollback) behavior. We can avoid these problems by enabling autocommit
|
||||
at both the pyodbc and engine levels::
|
||||
|
||||
connection_url = sa.engine.URL.create(
|
||||
"mssql+pyodbc",
|
||||
username="scott",
|
||||
password="tiger",
|
||||
host="dw.azure.example.com",
|
||||
database="mydb",
|
||||
query={
|
||||
"driver": "ODBC Driver 17 for SQL Server",
|
||||
"autocommit": "True",
|
||||
},
|
||||
)
|
||||
|
||||
engine = create_engine(connection_url).execution_options(
|
||||
isolation_level="AUTOCOMMIT"
|
||||
)
|
||||
|
||||
Avoiding sending large string parameters as TEXT/NTEXT
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
By default, for historical reasons, Microsoft's ODBC drivers for SQL Server
|
||||
send long string parameters (greater than 4000 SBCS characters or 2000 Unicode
|
||||
characters) as TEXT/NTEXT values. TEXT and NTEXT have been deprecated for many
|
||||
years and are starting to cause compatibility issues with newer versions of
|
||||
SQL_Server/Azure. For example, see `this
|
||||
issue <https://github.com/mkleehammer/pyodbc/issues/835>`_.
|
||||
|
||||
Starting with ODBC Driver 18 for SQL Server we can override the legacy
|
||||
behavior and pass long strings as varchar(max)/nvarchar(max) using the
|
||||
``LongAsMax=Yes`` connection string parameter::
|
||||
|
||||
connection_url = sa.engine.URL.create(
|
||||
"mssql+pyodbc",
|
||||
username="scott",
|
||||
password="tiger",
|
||||
host="mssqlserver.example.com",
|
||||
database="mydb",
|
||||
query={
|
||||
"driver": "ODBC Driver 18 for SQL Server",
|
||||
"LongAsMax": "Yes",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
Pyodbc Pooling / connection close behavior
|
||||
------------------------------------------
|
||||
|
||||
PyODBC uses internal `pooling
|
||||
<https://github.com/mkleehammer/pyodbc/wiki/The-pyodbc-Module#pooling>`_ by
|
||||
default, which means connections will be longer lived than they are within
|
||||
SQLAlchemy itself. As SQLAlchemy has its own pooling behavior, it is often
|
||||
preferable to disable this behavior. This behavior can only be disabled
|
||||
globally at the PyODBC module level, **before** any connections are made::
|
||||
|
||||
import pyodbc
|
||||
|
||||
pyodbc.pooling = False
|
||||
|
||||
# don't use the engine before pooling is set to False
|
||||
engine = create_engine("mssql+pyodbc://user:pass@dsn")
|
||||
|
||||
If this variable is left at its default value of ``True``, **the application
|
||||
will continue to maintain active database connections**, even when the
|
||||
SQLAlchemy engine itself fully discards a connection or if the engine is
|
||||
disposed.
|
||||
|
||||
.. seealso::
|
||||
|
||||
`pooling <https://github.com/mkleehammer/pyodbc/wiki/The-pyodbc-Module#pooling>`_ -
|
||||
in the PyODBC documentation.
|
||||
|
||||
Driver / Unicode Support
|
||||
-------------------------
|
||||
|
||||
PyODBC works best with Microsoft ODBC drivers, particularly in the area
|
||||
of Unicode support on both Python 2 and Python 3.
|
||||
|
||||
Using the FreeTDS ODBC drivers on Linux or OSX with PyODBC is **not**
|
||||
recommended; there have been historically many Unicode-related issues
|
||||
in this area, including before Microsoft offered ODBC drivers for Linux
|
||||
and OSX. Now that Microsoft offers drivers for all platforms, for
|
||||
PyODBC support these are recommended. FreeTDS remains relevant for
|
||||
non-ODBC drivers such as pymssql where it works very well.
|
||||
|
||||
|
||||
Rowcount Support
|
||||
----------------
|
||||
|
||||
Previous limitations with the SQLAlchemy ORM's "versioned rows" feature with
|
||||
Pyodbc have been resolved as of SQLAlchemy 2.0.5. See the notes at
|
||||
:ref:`mssql_rowcount_versioning`.
|
||||
|
||||
.. _mssql_pyodbc_fastexecutemany:
|
||||
|
||||
Fast Executemany Mode
|
||||
---------------------
|
||||
|
||||
The PyODBC driver includes support for a "fast executemany" mode of execution
|
||||
which greatly reduces round trips for a DBAPI ``executemany()`` call when using
|
||||
Microsoft ODBC drivers, for **limited size batches that fit in memory**. The
|
||||
feature is enabled by setting the attribute ``.fast_executemany`` on the DBAPI
|
||||
cursor when an executemany call is to be used. The SQLAlchemy PyODBC SQL
|
||||
Server dialect supports this parameter by passing the
|
||||
``fast_executemany`` parameter to
|
||||
:func:`_sa.create_engine` , when using the **Microsoft ODBC driver only**::
|
||||
|
||||
engine = create_engine(
|
||||
"mssql+pyodbc://scott:tiger@mssql2017:1433/test?driver=ODBC+Driver+17+for+SQL+Server",
|
||||
fast_executemany=True)
|
||||
|
||||
.. versionchanged:: 2.0.9 - the ``fast_executemany`` parameter now has its
|
||||
intended effect of this PyODBC feature taking effect for all INSERT
|
||||
statements that are executed with multiple parameter sets, which don't
|
||||
include RETURNING. Previously, SQLAlchemy 2.0's :term:`insertmanyvalues`
|
||||
feature would cause ``fast_executemany`` to not be used in most cases
|
||||
even if specified.
|
||||
|
||||
.. versionadded:: 1.3
|
||||
|
||||
.. seealso::
|
||||
|
||||
`fast executemany <https://github.com/mkleehammer/pyodbc/wiki/Features-beyond-the-DB-API#fast_executemany>`_
|
||||
- on github
|
||||
|
||||
.. _mssql_pyodbc_setinputsizes:
|
||||
|
||||
Setinputsizes Support
|
||||
-----------------------
|
||||
|
||||
As of version 2.0, the pyodbc ``cursor.setinputsizes()`` method is used for
|
||||
all statement executions, except for ``cursor.executemany()`` calls when
|
||||
fast_executemany=True where it is not supported (assuming
|
||||
:ref:`insertmanyvalues <engine_insertmanyvalues>` is kept enabled,
|
||||
"fastexecutemany" will not take place for INSERT statements in any case).
|
||||
|
||||
The use of ``cursor.setinputsizes()`` can be disabled by passing
|
||||
``use_setinputsizes=False`` to :func:`_sa.create_engine`.
|
||||
|
||||
When ``use_setinputsizes`` is left at its default of ``True``, the
|
||||
specific per-type symbols passed to ``cursor.setinputsizes()`` can be
|
||||
programmatically customized using the :meth:`.DialectEvents.do_setinputsizes`
|
||||
hook. See that method for usage examples.
|
||||
|
||||
.. versionchanged:: 2.0 The mssql+pyodbc dialect now defaults to using
|
||||
``use_setinputsizes=True`` for all statement executions with the exception of
|
||||
cursor.executemany() calls when fast_executemany=True. The behavior can
|
||||
be turned off by passing ``use_setinputsizes=False`` to
|
||||
:func:`_sa.create_engine`.
|
||||
|
||||
""" # noqa
|
||||
|
||||
|
||||
import datetime
|
||||
import decimal
|
||||
import re
|
||||
import struct
|
||||
|
||||
from .base import _MSDateTime
|
||||
from .base import _MSUnicode
|
||||
from .base import _MSUnicodeText
|
||||
from .base import BINARY
|
||||
from .base import DATETIMEOFFSET
|
||||
from .base import MSDialect
|
||||
from .base import MSExecutionContext
|
||||
from .base import VARBINARY
|
||||
from .json import JSON as _MSJson
|
||||
from .json import JSONIndexType as _MSJsonIndexType
|
||||
from .json import JSONPathType as _MSJsonPathType
|
||||
from ... import exc
|
||||
from ... import types as sqltypes
|
||||
from ... import util
|
||||
from ...connectors.pyodbc import PyODBCConnector
|
||||
from ...engine import cursor as _cursor
|
||||
|
||||
|
||||
class _ms_numeric_pyodbc:
|
||||
"""Turns Decimals with adjusted() < 0 or > 7 into strings.
|
||||
|
||||
The routines here are needed for older pyodbc versions
|
||||
as well as current mxODBC versions.
|
||||
|
||||
"""
|
||||
|
||||
def bind_processor(self, dialect):
|
||||
super_process = super().bind_processor(dialect)
|
||||
|
||||
if not dialect._need_decimal_fix:
|
||||
return super_process
|
||||
|
||||
def process(value):
|
||||
if self.asdecimal and isinstance(value, decimal.Decimal):
|
||||
adjusted = value.adjusted()
|
||||
if adjusted < 0:
|
||||
return self._small_dec_to_string(value)
|
||||
elif adjusted > 7:
|
||||
return self._large_dec_to_string(value)
|
||||
|
||||
if super_process:
|
||||
return super_process(value)
|
||||
else:
|
||||
return value
|
||||
|
||||
return process
|
||||
|
||||
# these routines needed for older versions of pyodbc.
|
||||
# as of 2.1.8 this logic is integrated.
|
||||
|
||||
def _small_dec_to_string(self, value):
|
||||
return "%s0.%s%s" % (
|
||||
(value < 0 and "-" or ""),
|
||||
"0" * (abs(value.adjusted()) - 1),
|
||||
"".join([str(nint) for nint in value.as_tuple()[1]]),
|
||||
)
|
||||
|
||||
def _large_dec_to_string(self, value):
|
||||
_int = value.as_tuple()[1]
|
||||
if "E" in str(value):
|
||||
result = "%s%s%s" % (
|
||||
(value < 0 and "-" or ""),
|
||||
"".join([str(s) for s in _int]),
|
||||
"0" * (value.adjusted() - (len(_int) - 1)),
|
||||
)
|
||||
else:
|
||||
if (len(_int) - 1) > value.adjusted():
|
||||
result = "%s%s.%s" % (
|
||||
(value < 0 and "-" or ""),
|
||||
"".join([str(s) for s in _int][0 : value.adjusted() + 1]),
|
||||
"".join([str(s) for s in _int][value.adjusted() + 1 :]),
|
||||
)
|
||||
else:
|
||||
result = "%s%s" % (
|
||||
(value < 0 and "-" or ""),
|
||||
"".join([str(s) for s in _int][0 : value.adjusted() + 1]),
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
class _MSNumeric_pyodbc(_ms_numeric_pyodbc, sqltypes.Numeric):
|
||||
pass
|
||||
|
||||
|
||||
class _MSFloat_pyodbc(_ms_numeric_pyodbc, sqltypes.Float):
|
||||
pass
|
||||
|
||||
|
||||
class _ms_binary_pyodbc:
|
||||
"""Wraps binary values in dialect-specific Binary wrapper.
|
||||
If the value is null, return a pyodbc-specific BinaryNull
|
||||
object to prevent pyODBC [and FreeTDS] from defaulting binary
|
||||
NULL types to SQLWCHAR and causing implicit conversion errors.
|
||||
"""
|
||||
|
||||
def bind_processor(self, dialect):
|
||||
if dialect.dbapi is None:
|
||||
return None
|
||||
|
||||
DBAPIBinary = dialect.dbapi.Binary
|
||||
|
||||
def process(value):
|
||||
if value is not None:
|
||||
return DBAPIBinary(value)
|
||||
else:
|
||||
# pyodbc-specific
|
||||
return dialect.dbapi.BinaryNull
|
||||
|
||||
return process
|
||||
|
||||
|
||||
class _ODBCDateTimeBindProcessor:
|
||||
"""Add bind processors to handle datetimeoffset behaviors"""
|
||||
|
||||
has_tz = False
|
||||
|
||||
def bind_processor(self, dialect):
|
||||
def process(value):
|
||||
if value is None:
|
||||
return None
|
||||
elif isinstance(value, str):
|
||||
# if a string was passed directly, allow it through
|
||||
return value
|
||||
elif not value.tzinfo or (not self.timezone and not self.has_tz):
|
||||
# for DateTime(timezone=False)
|
||||
return value
|
||||
else:
|
||||
# for DATETIMEOFFSET or DateTime(timezone=True)
|
||||
#
|
||||
# Convert to string format required by T-SQL
|
||||
dto_string = value.strftime("%Y-%m-%d %H:%M:%S.%f %z")
|
||||
# offset needs a colon, e.g., -0700 -> -07:00
|
||||
# "UTC offset in the form (+-)HHMM[SS[.ffffff]]"
|
||||
# backend currently rejects seconds / fractional seconds
|
||||
dto_string = re.sub(
|
||||
r"([\+\-]\d{2})([\d\.]+)$", r"\1:\2", dto_string
|
||||
)
|
||||
return dto_string
|
||||
|
||||
return process
|
||||
|
||||
|
||||
class _ODBCDateTime(_ODBCDateTimeBindProcessor, _MSDateTime):
|
||||
pass
|
||||
|
||||
|
||||
class _ODBCDATETIMEOFFSET(_ODBCDateTimeBindProcessor, DATETIMEOFFSET):
|
||||
has_tz = True
|
||||
|
||||
|
||||
class _VARBINARY_pyodbc(_ms_binary_pyodbc, VARBINARY):
|
||||
pass
|
||||
|
||||
|
||||
class _BINARY_pyodbc(_ms_binary_pyodbc, BINARY):
|
||||
pass
|
||||
|
||||
|
||||
class _String_pyodbc(sqltypes.String):
|
||||
def get_dbapi_type(self, dbapi):
|
||||
if self.length in (None, "max") or self.length >= 2000:
|
||||
return (dbapi.SQL_VARCHAR, 0, 0)
|
||||
else:
|
||||
return dbapi.SQL_VARCHAR
|
||||
|
||||
|
||||
class _Unicode_pyodbc(_MSUnicode):
|
||||
def get_dbapi_type(self, dbapi):
|
||||
if self.length in (None, "max") or self.length >= 2000:
|
||||
return (dbapi.SQL_WVARCHAR, 0, 0)
|
||||
else:
|
||||
return dbapi.SQL_WVARCHAR
|
||||
|
||||
|
||||
class _UnicodeText_pyodbc(_MSUnicodeText):
|
||||
def get_dbapi_type(self, dbapi):
|
||||
if self.length in (None, "max") or self.length >= 2000:
|
||||
return (dbapi.SQL_WVARCHAR, 0, 0)
|
||||
else:
|
||||
return dbapi.SQL_WVARCHAR
|
||||
|
||||
|
||||
class _JSON_pyodbc(_MSJson):
|
||||
def get_dbapi_type(self, dbapi):
|
||||
return (dbapi.SQL_WVARCHAR, 0, 0)
|
||||
|
||||
|
||||
class _JSONIndexType_pyodbc(_MSJsonIndexType):
|
||||
def get_dbapi_type(self, dbapi):
|
||||
return dbapi.SQL_WVARCHAR
|
||||
|
||||
|
||||
class _JSONPathType_pyodbc(_MSJsonPathType):
|
||||
def get_dbapi_type(self, dbapi):
|
||||
return dbapi.SQL_WVARCHAR
|
||||
|
||||
|
||||
class MSExecutionContext_pyodbc(MSExecutionContext):
|
||||
_embedded_scope_identity = False
|
||||
|
||||
def pre_exec(self):
|
||||
"""where appropriate, issue "select scope_identity()" in the same
|
||||
statement.
|
||||
|
||||
Background on why "scope_identity()" is preferable to "@@identity":
|
||||
https://msdn.microsoft.com/en-us/library/ms190315.aspx
|
||||
|
||||
Background on why we attempt to embed "scope_identity()" into the same
|
||||
statement as the INSERT:
|
||||
https://code.google.com/p/pyodbc/wiki/FAQs#How_do_I_retrieve_autogenerated/identity_values?
|
||||
|
||||
"""
|
||||
|
||||
super().pre_exec()
|
||||
|
||||
# don't embed the scope_identity select into an
|
||||
# "INSERT .. DEFAULT VALUES"
|
||||
if (
|
||||
self._select_lastrowid
|
||||
and self.dialect.use_scope_identity
|
||||
and len(self.parameters[0])
|
||||
):
|
||||
self._embedded_scope_identity = True
|
||||
|
||||
self.statement += "; select scope_identity()"
|
||||
|
||||
def post_exec(self):
|
||||
if self._embedded_scope_identity:
|
||||
# Fetch the last inserted id from the manipulated statement
|
||||
# We may have to skip over a number of result sets with
|
||||
# no data (due to triggers, etc.)
|
||||
while True:
|
||||
try:
|
||||
# fetchall() ensures the cursor is consumed
|
||||
# without closing it (FreeTDS particularly)
|
||||
rows = self.cursor.fetchall()
|
||||
except self.dialect.dbapi.Error:
|
||||
# no way around this - nextset() consumes the previous set
|
||||
# so we need to just keep flipping
|
||||
self.cursor.nextset()
|
||||
else:
|
||||
if not rows:
|
||||
# async adapter drivers just return None here
|
||||
self.cursor.nextset()
|
||||
continue
|
||||
row = rows[0]
|
||||
break
|
||||
|
||||
self._lastrowid = int(row[0])
|
||||
|
||||
self.cursor_fetch_strategy = _cursor._NO_CURSOR_DML
|
||||
else:
|
||||
super().post_exec()
|
||||
|
||||
|
||||
class MSDialect_pyodbc(PyODBCConnector, MSDialect):
|
||||
supports_statement_cache = True
|
||||
|
||||
# note this parameter is no longer used by the ORM or default dialect
|
||||
# see #9414
|
||||
supports_sane_rowcount_returning = False
|
||||
|
||||
execution_ctx_cls = MSExecutionContext_pyodbc
|
||||
|
||||
colspecs = util.update_copy(
|
||||
MSDialect.colspecs,
|
||||
{
|
||||
sqltypes.Numeric: _MSNumeric_pyodbc,
|
||||
sqltypes.Float: _MSFloat_pyodbc,
|
||||
BINARY: _BINARY_pyodbc,
|
||||
# support DateTime(timezone=True)
|
||||
sqltypes.DateTime: _ODBCDateTime,
|
||||
DATETIMEOFFSET: _ODBCDATETIMEOFFSET,
|
||||
# SQL Server dialect has a VARBINARY that is just to support
|
||||
# "deprecate_large_types" w/ VARBINARY(max), but also we must
|
||||
# handle the usual SQL standard VARBINARY
|
||||
VARBINARY: _VARBINARY_pyodbc,
|
||||
sqltypes.VARBINARY: _VARBINARY_pyodbc,
|
||||
sqltypes.LargeBinary: _VARBINARY_pyodbc,
|
||||
sqltypes.String: _String_pyodbc,
|
||||
sqltypes.Unicode: _Unicode_pyodbc,
|
||||
sqltypes.UnicodeText: _UnicodeText_pyodbc,
|
||||
sqltypes.JSON: _JSON_pyodbc,
|
||||
sqltypes.JSON.JSONIndexType: _JSONIndexType_pyodbc,
|
||||
sqltypes.JSON.JSONPathType: _JSONPathType_pyodbc,
|
||||
# this excludes Enum from the string/VARCHAR thing for now
|
||||
# it looks like Enum's adaptation doesn't really support the
|
||||
# String type itself having a dialect-level impl
|
||||
sqltypes.Enum: sqltypes.Enum,
|
||||
},
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fast_executemany=False,
|
||||
use_setinputsizes=True,
|
||||
**params,
|
||||
):
|
||||
super().__init__(use_setinputsizes=use_setinputsizes, **params)
|
||||
self.use_scope_identity = (
|
||||
self.use_scope_identity
|
||||
and self.dbapi
|
||||
and hasattr(self.dbapi.Cursor, "nextset")
|
||||
)
|
||||
self._need_decimal_fix = self.dbapi and self._dbapi_version() < (
|
||||
2,
|
||||
1,
|
||||
8,
|
||||
)
|
||||
self.fast_executemany = fast_executemany
|
||||
if fast_executemany:
|
||||
self.use_insertmanyvalues_wo_returning = False
|
||||
|
||||
def _get_server_version_info(self, connection):
|
||||
try:
|
||||
# "Version of the instance of SQL Server, in the form
|
||||
# of 'major.minor.build.revision'"
|
||||
raw = connection.exec_driver_sql(
|
||||
"SELECT CAST(SERVERPROPERTY('ProductVersion') AS VARCHAR)"
|
||||
).scalar()
|
||||
except exc.DBAPIError:
|
||||
# SQL Server docs indicate this function isn't present prior to
|
||||
# 2008. Before we had the VARCHAR cast above, pyodbc would also
|
||||
# fail on this query.
|
||||
return super()._get_server_version_info(connection)
|
||||
else:
|
||||
version = []
|
||||
r = re.compile(r"[.\-]")
|
||||
for n in r.split(raw):
|
||||
try:
|
||||
version.append(int(n))
|
||||
except ValueError:
|
||||
pass
|
||||
return tuple(version)
|
||||
|
||||
def on_connect(self):
|
||||
super_ = super().on_connect()
|
||||
|
||||
def on_connect(conn):
|
||||
if super_ is not None:
|
||||
super_(conn)
|
||||
|
||||
self._setup_timestampoffset_type(conn)
|
||||
|
||||
return on_connect
|
||||
|
||||
def _setup_timestampoffset_type(self, connection):
|
||||
# output converter function for datetimeoffset
|
||||
def _handle_datetimeoffset(dto_value):
|
||||
tup = struct.unpack("<6hI2h", dto_value)
|
||||
return datetime.datetime(
|
||||
tup[0],
|
||||
tup[1],
|
||||
tup[2],
|
||||
tup[3],
|
||||
tup[4],
|
||||
tup[5],
|
||||
tup[6] // 1000,
|
||||
datetime.timezone(
|
||||
datetime.timedelta(hours=tup[7], minutes=tup[8])
|
||||
),
|
||||
)
|
||||
|
||||
odbc_SQL_SS_TIMESTAMPOFFSET = -155 # as defined in SQLNCLI.h
|
||||
connection.add_output_converter(
|
||||
odbc_SQL_SS_TIMESTAMPOFFSET, _handle_datetimeoffset
|
||||
)
|
||||
|
||||
def do_executemany(self, cursor, statement, parameters, context=None):
|
||||
if self.fast_executemany:
|
||||
cursor.fast_executemany = True
|
||||
super().do_executemany(cursor, statement, parameters, context=context)
|
||||
|
||||
def is_disconnect(self, e, connection, cursor):
|
||||
if isinstance(e, self.dbapi.Error):
|
||||
code = e.args[0]
|
||||
if code in {
|
||||
"08S01",
|
||||
"01000",
|
||||
"01002",
|
||||
"08003",
|
||||
"08007",
|
||||
"08S02",
|
||||
"08001",
|
||||
"HYT00",
|
||||
"HY010",
|
||||
"10054",
|
||||
}:
|
||||
return True
|
||||
return super().is_disconnect(e, connection, cursor)
|
||||
|
||||
|
||||
dialect = MSDialect_pyodbc
|
|
@ -0,0 +1,101 @@
|
|||
# dialects/mysql/__init__.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
from . import aiomysql # noqa
|
||||
from . import asyncmy # noqa
|
||||
from . import base # noqa
|
||||
from . import cymysql # noqa
|
||||
from . import mariadbconnector # noqa
|
||||
from . import mysqlconnector # noqa
|
||||
from . import mysqldb # noqa
|
||||
from . import pymysql # noqa
|
||||
from . import pyodbc # noqa
|
||||
from .base import BIGINT
|
||||
from .base import BINARY
|
||||
from .base import BIT
|
||||
from .base import BLOB
|
||||
from .base import BOOLEAN
|
||||
from .base import CHAR
|
||||
from .base import DATE
|
||||
from .base import DATETIME
|
||||
from .base import DECIMAL
|
||||
from .base import DOUBLE
|
||||
from .base import ENUM
|
||||
from .base import FLOAT
|
||||
from .base import INTEGER
|
||||
from .base import JSON
|
||||
from .base import LONGBLOB
|
||||
from .base import LONGTEXT
|
||||
from .base import MEDIUMBLOB
|
||||
from .base import MEDIUMINT
|
||||
from .base import MEDIUMTEXT
|
||||
from .base import NCHAR
|
||||
from .base import NUMERIC
|
||||
from .base import NVARCHAR
|
||||
from .base import REAL
|
||||
from .base import SET
|
||||
from .base import SMALLINT
|
||||
from .base import TEXT
|
||||
from .base import TIME
|
||||
from .base import TIMESTAMP
|
||||
from .base import TINYBLOB
|
||||
from .base import TINYINT
|
||||
from .base import TINYTEXT
|
||||
from .base import VARBINARY
|
||||
from .base import VARCHAR
|
||||
from .base import YEAR
|
||||
from .dml import Insert
|
||||
from .dml import insert
|
||||
from .expression import match
|
||||
from ...util import compat
|
||||
|
||||
# default dialect
|
||||
base.dialect = dialect = mysqldb.dialect
|
||||
|
||||
__all__ = (
|
||||
"BIGINT",
|
||||
"BINARY",
|
||||
"BIT",
|
||||
"BLOB",
|
||||
"BOOLEAN",
|
||||
"CHAR",
|
||||
"DATE",
|
||||
"DATETIME",
|
||||
"DECIMAL",
|
||||
"DOUBLE",
|
||||
"ENUM",
|
||||
"FLOAT",
|
||||
"INTEGER",
|
||||
"INTEGER",
|
||||
"JSON",
|
||||
"LONGBLOB",
|
||||
"LONGTEXT",
|
||||
"MEDIUMBLOB",
|
||||
"MEDIUMINT",
|
||||
"MEDIUMTEXT",
|
||||
"NCHAR",
|
||||
"NVARCHAR",
|
||||
"NUMERIC",
|
||||
"SET",
|
||||
"SMALLINT",
|
||||
"REAL",
|
||||
"TEXT",
|
||||
"TIME",
|
||||
"TIMESTAMP",
|
||||
"TINYBLOB",
|
||||
"TINYINT",
|
||||
"TINYTEXT",
|
||||
"VARBINARY",
|
||||
"VARCHAR",
|
||||
"YEAR",
|
||||
"dialect",
|
||||
"insert",
|
||||
"Insert",
|
||||
"match",
|
||||
)
|
|
@ -0,0 +1,332 @@
|
|||
# dialects/mysql/aiomysql.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors <see AUTHORS
|
||||
# file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
r"""
|
||||
.. dialect:: mysql+aiomysql
|
||||
:name: aiomysql
|
||||
:dbapi: aiomysql
|
||||
:connectstring: mysql+aiomysql://user:password@host:port/dbname[?key=value&key=value...]
|
||||
:url: https://github.com/aio-libs/aiomysql
|
||||
|
||||
The aiomysql dialect is SQLAlchemy's second Python asyncio dialect.
|
||||
|
||||
Using a special asyncio mediation layer, the aiomysql dialect is usable
|
||||
as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>`
|
||||
extension package.
|
||||
|
||||
This dialect should normally be used only with the
|
||||
:func:`_asyncio.create_async_engine` engine creation function::
|
||||
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
engine = create_async_engine("mysql+aiomysql://user:pass@hostname/dbname?charset=utf8mb4")
|
||||
|
||||
|
||||
""" # noqa
|
||||
from .pymysql import MySQLDialect_pymysql
|
||||
from ... import pool
|
||||
from ... import util
|
||||
from ...engine import AdaptedConnection
|
||||
from ...util.concurrency import asyncio
|
||||
from ...util.concurrency import await_fallback
|
||||
from ...util.concurrency import await_only
|
||||
|
||||
|
||||
class AsyncAdapt_aiomysql_cursor:
|
||||
# TODO: base on connectors/asyncio.py
|
||||
# see #10415
|
||||
server_side = False
|
||||
__slots__ = (
|
||||
"_adapt_connection",
|
||||
"_connection",
|
||||
"await_",
|
||||
"_cursor",
|
||||
"_rows",
|
||||
)
|
||||
|
||||
def __init__(self, adapt_connection):
|
||||
self._adapt_connection = adapt_connection
|
||||
self._connection = adapt_connection._connection
|
||||
self.await_ = adapt_connection.await_
|
||||
|
||||
cursor = self._connection.cursor(adapt_connection.dbapi.Cursor)
|
||||
|
||||
# see https://github.com/aio-libs/aiomysql/issues/543
|
||||
self._cursor = self.await_(cursor.__aenter__())
|
||||
self._rows = []
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return self._cursor.description
|
||||
|
||||
@property
|
||||
def rowcount(self):
|
||||
return self._cursor.rowcount
|
||||
|
||||
@property
|
||||
def arraysize(self):
|
||||
return self._cursor.arraysize
|
||||
|
||||
@arraysize.setter
|
||||
def arraysize(self, value):
|
||||
self._cursor.arraysize = value
|
||||
|
||||
@property
|
||||
def lastrowid(self):
|
||||
return self._cursor.lastrowid
|
||||
|
||||
def close(self):
|
||||
# note we aren't actually closing the cursor here,
|
||||
# we are just letting GC do it. to allow this to be async
|
||||
# we would need the Result to change how it does "Safe close cursor".
|
||||
# MySQL "cursors" don't actually have state to be "closed" besides
|
||||
# exhausting rows, which we already have done for sync cursor.
|
||||
# another option would be to emulate aiosqlite dialect and assign
|
||||
# cursor only if we are doing server side cursor operation.
|
||||
self._rows[:] = []
|
||||
|
||||
def execute(self, operation, parameters=None):
|
||||
return self.await_(self._execute_async(operation, parameters))
|
||||
|
||||
def executemany(self, operation, seq_of_parameters):
|
||||
return self.await_(
|
||||
self._executemany_async(operation, seq_of_parameters)
|
||||
)
|
||||
|
||||
async def _execute_async(self, operation, parameters):
|
||||
async with self._adapt_connection._execute_mutex:
|
||||
result = await self._cursor.execute(operation, parameters)
|
||||
|
||||
if not self.server_side:
|
||||
# aiomysql has a "fake" async result, so we have to pull it out
|
||||
# of that here since our default result is not async.
|
||||
# we could just as easily grab "_rows" here and be done with it
|
||||
# but this is safer.
|
||||
self._rows = list(await self._cursor.fetchall())
|
||||
return result
|
||||
|
||||
async def _executemany_async(self, operation, seq_of_parameters):
|
||||
async with self._adapt_connection._execute_mutex:
|
||||
return await self._cursor.executemany(operation, seq_of_parameters)
|
||||
|
||||
def setinputsizes(self, *inputsizes):
|
||||
pass
|
||||
|
||||
def __iter__(self):
|
||||
while self._rows:
|
||||
yield self._rows.pop(0)
|
||||
|
||||
def fetchone(self):
|
||||
if self._rows:
|
||||
return self._rows.pop(0)
|
||||
else:
|
||||
return None
|
||||
|
||||
def fetchmany(self, size=None):
|
||||
if size is None:
|
||||
size = self.arraysize
|
||||
|
||||
retval = self._rows[0:size]
|
||||
self._rows[:] = self._rows[size:]
|
||||
return retval
|
||||
|
||||
def fetchall(self):
|
||||
retval = self._rows[:]
|
||||
self._rows[:] = []
|
||||
return retval
|
||||
|
||||
|
||||
class AsyncAdapt_aiomysql_ss_cursor(AsyncAdapt_aiomysql_cursor):
|
||||
# TODO: base on connectors/asyncio.py
|
||||
# see #10415
|
||||
__slots__ = ()
|
||||
server_side = True
|
||||
|
||||
def __init__(self, adapt_connection):
|
||||
self._adapt_connection = adapt_connection
|
||||
self._connection = adapt_connection._connection
|
||||
self.await_ = adapt_connection.await_
|
||||
|
||||
cursor = self._connection.cursor(adapt_connection.dbapi.SSCursor)
|
||||
|
||||
self._cursor = self.await_(cursor.__aenter__())
|
||||
|
||||
def close(self):
|
||||
if self._cursor is not None:
|
||||
self.await_(self._cursor.close())
|
||||
self._cursor = None
|
||||
|
||||
def fetchone(self):
|
||||
return self.await_(self._cursor.fetchone())
|
||||
|
||||
def fetchmany(self, size=None):
|
||||
return self.await_(self._cursor.fetchmany(size=size))
|
||||
|
||||
def fetchall(self):
|
||||
return self.await_(self._cursor.fetchall())
|
||||
|
||||
|
||||
class AsyncAdapt_aiomysql_connection(AdaptedConnection):
|
||||
# TODO: base on connectors/asyncio.py
|
||||
# see #10415
|
||||
await_ = staticmethod(await_only)
|
||||
__slots__ = ("dbapi", "_execute_mutex")
|
||||
|
||||
def __init__(self, dbapi, connection):
|
||||
self.dbapi = dbapi
|
||||
self._connection = connection
|
||||
self._execute_mutex = asyncio.Lock()
|
||||
|
||||
def ping(self, reconnect):
|
||||
return self.await_(self._connection.ping(reconnect))
|
||||
|
||||
def character_set_name(self):
|
||||
return self._connection.character_set_name()
|
||||
|
||||
def autocommit(self, value):
|
||||
self.await_(self._connection.autocommit(value))
|
||||
|
||||
def cursor(self, server_side=False):
|
||||
if server_side:
|
||||
return AsyncAdapt_aiomysql_ss_cursor(self)
|
||||
else:
|
||||
return AsyncAdapt_aiomysql_cursor(self)
|
||||
|
||||
def rollback(self):
|
||||
self.await_(self._connection.rollback())
|
||||
|
||||
def commit(self):
|
||||
self.await_(self._connection.commit())
|
||||
|
||||
def terminate(self):
|
||||
# it's not awaitable.
|
||||
self._connection.close()
|
||||
|
||||
def close(self) -> None:
|
||||
self.await_(self._connection.ensure_closed())
|
||||
|
||||
|
||||
class AsyncAdaptFallback_aiomysql_connection(AsyncAdapt_aiomysql_connection):
|
||||
# TODO: base on connectors/asyncio.py
|
||||
# see #10415
|
||||
__slots__ = ()
|
||||
|
||||
await_ = staticmethod(await_fallback)
|
||||
|
||||
|
||||
class AsyncAdapt_aiomysql_dbapi:
|
||||
def __init__(self, aiomysql, pymysql):
|
||||
self.aiomysql = aiomysql
|
||||
self.pymysql = pymysql
|
||||
self.paramstyle = "format"
|
||||
self._init_dbapi_attributes()
|
||||
self.Cursor, self.SSCursor = self._init_cursors_subclasses()
|
||||
|
||||
def _init_dbapi_attributes(self):
|
||||
for name in (
|
||||
"Warning",
|
||||
"Error",
|
||||
"InterfaceError",
|
||||
"DataError",
|
||||
"DatabaseError",
|
||||
"OperationalError",
|
||||
"InterfaceError",
|
||||
"IntegrityError",
|
||||
"ProgrammingError",
|
||||
"InternalError",
|
||||
"NotSupportedError",
|
||||
):
|
||||
setattr(self, name, getattr(self.aiomysql, name))
|
||||
|
||||
for name in (
|
||||
"NUMBER",
|
||||
"STRING",
|
||||
"DATETIME",
|
||||
"BINARY",
|
||||
"TIMESTAMP",
|
||||
"Binary",
|
||||
):
|
||||
setattr(self, name, getattr(self.pymysql, name))
|
||||
|
||||
def connect(self, *arg, **kw):
|
||||
async_fallback = kw.pop("async_fallback", False)
|
||||
creator_fn = kw.pop("async_creator_fn", self.aiomysql.connect)
|
||||
|
||||
if util.asbool(async_fallback):
|
||||
return AsyncAdaptFallback_aiomysql_connection(
|
||||
self,
|
||||
await_fallback(creator_fn(*arg, **kw)),
|
||||
)
|
||||
else:
|
||||
return AsyncAdapt_aiomysql_connection(
|
||||
self,
|
||||
await_only(creator_fn(*arg, **kw)),
|
||||
)
|
||||
|
||||
def _init_cursors_subclasses(self):
|
||||
# suppress unconditional warning emitted by aiomysql
|
||||
class Cursor(self.aiomysql.Cursor):
|
||||
async def _show_warnings(self, conn):
|
||||
pass
|
||||
|
||||
class SSCursor(self.aiomysql.SSCursor):
|
||||
async def _show_warnings(self, conn):
|
||||
pass
|
||||
|
||||
return Cursor, SSCursor
|
||||
|
||||
|
||||
class MySQLDialect_aiomysql(MySQLDialect_pymysql):
|
||||
driver = "aiomysql"
|
||||
supports_statement_cache = True
|
||||
|
||||
supports_server_side_cursors = True
|
||||
_sscursor = AsyncAdapt_aiomysql_ss_cursor
|
||||
|
||||
is_async = True
|
||||
has_terminate = True
|
||||
|
||||
@classmethod
|
||||
def import_dbapi(cls):
|
||||
return AsyncAdapt_aiomysql_dbapi(
|
||||
__import__("aiomysql"), __import__("pymysql")
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_pool_class(cls, url):
|
||||
async_fallback = url.query.get("async_fallback", False)
|
||||
|
||||
if util.asbool(async_fallback):
|
||||
return pool.FallbackAsyncAdaptedQueuePool
|
||||
else:
|
||||
return pool.AsyncAdaptedQueuePool
|
||||
|
||||
def do_terminate(self, dbapi_connection) -> None:
|
||||
dbapi_connection.terminate()
|
||||
|
||||
def create_connect_args(self, url):
|
||||
return super().create_connect_args(
|
||||
url, _translate_args=dict(username="user", database="db")
|
||||
)
|
||||
|
||||
def is_disconnect(self, e, connection, cursor):
|
||||
if super().is_disconnect(e, connection, cursor):
|
||||
return True
|
||||
else:
|
||||
str_e = str(e).lower()
|
||||
return "not connected" in str_e
|
||||
|
||||
def _found_rows_client_flag(self):
|
||||
from pymysql.constants import CLIENT
|
||||
|
||||
return CLIENT.FOUND_ROWS
|
||||
|
||||
def get_driver_connection(self, connection):
|
||||
return connection._connection
|
||||
|
||||
|
||||
dialect = MySQLDialect_aiomysql
|
|
@ -0,0 +1,337 @@
|
|||
# dialects/mysql/asyncmy.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors <see AUTHORS
|
||||
# file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
r"""
|
||||
.. dialect:: mysql+asyncmy
|
||||
:name: asyncmy
|
||||
:dbapi: asyncmy
|
||||
:connectstring: mysql+asyncmy://user:password@host:port/dbname[?key=value&key=value...]
|
||||
:url: https://github.com/long2ice/asyncmy
|
||||
|
||||
Using a special asyncio mediation layer, the asyncmy dialect is usable
|
||||
as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>`
|
||||
extension package.
|
||||
|
||||
This dialect should normally be used only with the
|
||||
:func:`_asyncio.create_async_engine` engine creation function::
|
||||
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
engine = create_async_engine("mysql+asyncmy://user:pass@hostname/dbname?charset=utf8mb4")
|
||||
|
||||
|
||||
""" # noqa
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from .pymysql import MySQLDialect_pymysql
|
||||
from ... import pool
|
||||
from ... import util
|
||||
from ...engine import AdaptedConnection
|
||||
from ...util.concurrency import asyncio
|
||||
from ...util.concurrency import await_fallback
|
||||
from ...util.concurrency import await_only
|
||||
|
||||
|
||||
class AsyncAdapt_asyncmy_cursor:
|
||||
# TODO: base on connectors/asyncio.py
|
||||
# see #10415
|
||||
server_side = False
|
||||
__slots__ = (
|
||||
"_adapt_connection",
|
||||
"_connection",
|
||||
"await_",
|
||||
"_cursor",
|
||||
"_rows",
|
||||
)
|
||||
|
||||
def __init__(self, adapt_connection):
|
||||
self._adapt_connection = adapt_connection
|
||||
self._connection = adapt_connection._connection
|
||||
self.await_ = adapt_connection.await_
|
||||
|
||||
cursor = self._connection.cursor()
|
||||
|
||||
self._cursor = self.await_(cursor.__aenter__())
|
||||
self._rows = []
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return self._cursor.description
|
||||
|
||||
@property
|
||||
def rowcount(self):
|
||||
return self._cursor.rowcount
|
||||
|
||||
@property
|
||||
def arraysize(self):
|
||||
return self._cursor.arraysize
|
||||
|
||||
@arraysize.setter
|
||||
def arraysize(self, value):
|
||||
self._cursor.arraysize = value
|
||||
|
||||
@property
|
||||
def lastrowid(self):
|
||||
return self._cursor.lastrowid
|
||||
|
||||
def close(self):
|
||||
# note we aren't actually closing the cursor here,
|
||||
# we are just letting GC do it. to allow this to be async
|
||||
# we would need the Result to change how it does "Safe close cursor".
|
||||
# MySQL "cursors" don't actually have state to be "closed" besides
|
||||
# exhausting rows, which we already have done for sync cursor.
|
||||
# another option would be to emulate aiosqlite dialect and assign
|
||||
# cursor only if we are doing server side cursor operation.
|
||||
self._rows[:] = []
|
||||
|
||||
def execute(self, operation, parameters=None):
|
||||
return self.await_(self._execute_async(operation, parameters))
|
||||
|
||||
def executemany(self, operation, seq_of_parameters):
|
||||
return self.await_(
|
||||
self._executemany_async(operation, seq_of_parameters)
|
||||
)
|
||||
|
||||
async def _execute_async(self, operation, parameters):
|
||||
async with self._adapt_connection._mutex_and_adapt_errors():
|
||||
if parameters is None:
|
||||
result = await self._cursor.execute(operation)
|
||||
else:
|
||||
result = await self._cursor.execute(operation, parameters)
|
||||
|
||||
if not self.server_side:
|
||||
# asyncmy has a "fake" async result, so we have to pull it out
|
||||
# of that here since our default result is not async.
|
||||
# we could just as easily grab "_rows" here and be done with it
|
||||
# but this is safer.
|
||||
self._rows = list(await self._cursor.fetchall())
|
||||
return result
|
||||
|
||||
async def _executemany_async(self, operation, seq_of_parameters):
|
||||
async with self._adapt_connection._mutex_and_adapt_errors():
|
||||
return await self._cursor.executemany(operation, seq_of_parameters)
|
||||
|
||||
def setinputsizes(self, *inputsizes):
|
||||
pass
|
||||
|
||||
def __iter__(self):
|
||||
while self._rows:
|
||||
yield self._rows.pop(0)
|
||||
|
||||
def fetchone(self):
|
||||
if self._rows:
|
||||
return self._rows.pop(0)
|
||||
else:
|
||||
return None
|
||||
|
||||
def fetchmany(self, size=None):
|
||||
if size is None:
|
||||
size = self.arraysize
|
||||
|
||||
retval = self._rows[0:size]
|
||||
self._rows[:] = self._rows[size:]
|
||||
return retval
|
||||
|
||||
def fetchall(self):
|
||||
retval = self._rows[:]
|
||||
self._rows[:] = []
|
||||
return retval
|
||||
|
||||
|
||||
class AsyncAdapt_asyncmy_ss_cursor(AsyncAdapt_asyncmy_cursor):
|
||||
# TODO: base on connectors/asyncio.py
|
||||
# see #10415
|
||||
__slots__ = ()
|
||||
server_side = True
|
||||
|
||||
def __init__(self, adapt_connection):
|
||||
self._adapt_connection = adapt_connection
|
||||
self._connection = adapt_connection._connection
|
||||
self.await_ = adapt_connection.await_
|
||||
|
||||
cursor = self._connection.cursor(
|
||||
adapt_connection.dbapi.asyncmy.cursors.SSCursor
|
||||
)
|
||||
|
||||
self._cursor = self.await_(cursor.__aenter__())
|
||||
|
||||
def close(self):
|
||||
if self._cursor is not None:
|
||||
self.await_(self._cursor.close())
|
||||
self._cursor = None
|
||||
|
||||
def fetchone(self):
|
||||
return self.await_(self._cursor.fetchone())
|
||||
|
||||
def fetchmany(self, size=None):
|
||||
return self.await_(self._cursor.fetchmany(size=size))
|
||||
|
||||
def fetchall(self):
|
||||
return self.await_(self._cursor.fetchall())
|
||||
|
||||
|
||||
class AsyncAdapt_asyncmy_connection(AdaptedConnection):
|
||||
# TODO: base on connectors/asyncio.py
|
||||
# see #10415
|
||||
await_ = staticmethod(await_only)
|
||||
__slots__ = ("dbapi", "_execute_mutex")
|
||||
|
||||
def __init__(self, dbapi, connection):
|
||||
self.dbapi = dbapi
|
||||
self._connection = connection
|
||||
self._execute_mutex = asyncio.Lock()
|
||||
|
||||
@asynccontextmanager
|
||||
async def _mutex_and_adapt_errors(self):
|
||||
async with self._execute_mutex:
|
||||
try:
|
||||
yield
|
||||
except AttributeError:
|
||||
raise self.dbapi.InternalError(
|
||||
"network operation failed due to asyncmy attribute error"
|
||||
)
|
||||
|
||||
def ping(self, reconnect):
|
||||
assert not reconnect
|
||||
return self.await_(self._do_ping())
|
||||
|
||||
async def _do_ping(self):
|
||||
async with self._mutex_and_adapt_errors():
|
||||
return await self._connection.ping(False)
|
||||
|
||||
def character_set_name(self):
|
||||
return self._connection.character_set_name()
|
||||
|
||||
def autocommit(self, value):
|
||||
self.await_(self._connection.autocommit(value))
|
||||
|
||||
def cursor(self, server_side=False):
|
||||
if server_side:
|
||||
return AsyncAdapt_asyncmy_ss_cursor(self)
|
||||
else:
|
||||
return AsyncAdapt_asyncmy_cursor(self)
|
||||
|
||||
def rollback(self):
|
||||
self.await_(self._connection.rollback())
|
||||
|
||||
def commit(self):
|
||||
self.await_(self._connection.commit())
|
||||
|
||||
def terminate(self):
|
||||
# it's not awaitable.
|
||||
self._connection.close()
|
||||
|
||||
def close(self) -> None:
|
||||
self.await_(self._connection.ensure_closed())
|
||||
|
||||
|
||||
class AsyncAdaptFallback_asyncmy_connection(AsyncAdapt_asyncmy_connection):
|
||||
__slots__ = ()
|
||||
|
||||
await_ = staticmethod(await_fallback)
|
||||
|
||||
|
||||
def _Binary(x):
|
||||
"""Return x as a binary type."""
|
||||
return bytes(x)
|
||||
|
||||
|
||||
class AsyncAdapt_asyncmy_dbapi:
|
||||
def __init__(self, asyncmy):
|
||||
self.asyncmy = asyncmy
|
||||
self.paramstyle = "format"
|
||||
self._init_dbapi_attributes()
|
||||
|
||||
def _init_dbapi_attributes(self):
|
||||
for name in (
|
||||
"Warning",
|
||||
"Error",
|
||||
"InterfaceError",
|
||||
"DataError",
|
||||
"DatabaseError",
|
||||
"OperationalError",
|
||||
"InterfaceError",
|
||||
"IntegrityError",
|
||||
"ProgrammingError",
|
||||
"InternalError",
|
||||
"NotSupportedError",
|
||||
):
|
||||
setattr(self, name, getattr(self.asyncmy.errors, name))
|
||||
|
||||
STRING = util.symbol("STRING")
|
||||
NUMBER = util.symbol("NUMBER")
|
||||
BINARY = util.symbol("BINARY")
|
||||
DATETIME = util.symbol("DATETIME")
|
||||
TIMESTAMP = util.symbol("TIMESTAMP")
|
||||
Binary = staticmethod(_Binary)
|
||||
|
||||
def connect(self, *arg, **kw):
|
||||
async_fallback = kw.pop("async_fallback", False)
|
||||
creator_fn = kw.pop("async_creator_fn", self.asyncmy.connect)
|
||||
|
||||
if util.asbool(async_fallback):
|
||||
return AsyncAdaptFallback_asyncmy_connection(
|
||||
self,
|
||||
await_fallback(creator_fn(*arg, **kw)),
|
||||
)
|
||||
else:
|
||||
return AsyncAdapt_asyncmy_connection(
|
||||
self,
|
||||
await_only(creator_fn(*arg, **kw)),
|
||||
)
|
||||
|
||||
|
||||
class MySQLDialect_asyncmy(MySQLDialect_pymysql):
|
||||
driver = "asyncmy"
|
||||
supports_statement_cache = True
|
||||
|
||||
supports_server_side_cursors = True
|
||||
_sscursor = AsyncAdapt_asyncmy_ss_cursor
|
||||
|
||||
is_async = True
|
||||
has_terminate = True
|
||||
|
||||
@classmethod
|
||||
def import_dbapi(cls):
|
||||
return AsyncAdapt_asyncmy_dbapi(__import__("asyncmy"))
|
||||
|
||||
@classmethod
|
||||
def get_pool_class(cls, url):
|
||||
async_fallback = url.query.get("async_fallback", False)
|
||||
|
||||
if util.asbool(async_fallback):
|
||||
return pool.FallbackAsyncAdaptedQueuePool
|
||||
else:
|
||||
return pool.AsyncAdaptedQueuePool
|
||||
|
||||
def do_terminate(self, dbapi_connection) -> None:
|
||||
dbapi_connection.terminate()
|
||||
|
||||
def create_connect_args(self, url):
|
||||
return super().create_connect_args(
|
||||
url, _translate_args=dict(username="user", database="db")
|
||||
)
|
||||
|
||||
def is_disconnect(self, e, connection, cursor):
|
||||
if super().is_disconnect(e, connection, cursor):
|
||||
return True
|
||||
else:
|
||||
str_e = str(e).lower()
|
||||
return (
|
||||
"not connected" in str_e or "network operation failed" in str_e
|
||||
)
|
||||
|
||||
def _found_rows_client_flag(self):
|
||||
from asyncmy.constants import CLIENT
|
||||
|
||||
return CLIENT.FOUND_ROWS
|
||||
|
||||
def get_driver_connection(self, connection):
|
||||
return connection._connection
|
||||
|
||||
|
||||
dialect = MySQLDialect_asyncmy
|
File diff suppressed because it is too large
Load diff
|
@ -0,0 +1,84 @@
|
|||
# dialects/mysql/cymysql.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
r"""
|
||||
|
||||
.. dialect:: mysql+cymysql
|
||||
:name: CyMySQL
|
||||
:dbapi: cymysql
|
||||
:connectstring: mysql+cymysql://<username>:<password>@<host>/<dbname>[?<options>]
|
||||
:url: https://github.com/nakagami/CyMySQL
|
||||
|
||||
.. note::
|
||||
|
||||
The CyMySQL dialect is **not tested as part of SQLAlchemy's continuous
|
||||
integration** and may have unresolved issues. The recommended MySQL
|
||||
dialects are mysqlclient and PyMySQL.
|
||||
|
||||
""" # noqa
|
||||
|
||||
from .base import BIT
|
||||
from .base import MySQLDialect
|
||||
from .mysqldb import MySQLDialect_mysqldb
|
||||
from ... import util
|
||||
|
||||
|
||||
class _cymysqlBIT(BIT):
|
||||
def result_processor(self, dialect, coltype):
|
||||
"""Convert MySQL's 64 bit, variable length binary string to a long."""
|
||||
|
||||
def process(value):
|
||||
if value is not None:
|
||||
v = 0
|
||||
for i in iter(value):
|
||||
v = v << 8 | i
|
||||
return v
|
||||
return value
|
||||
|
||||
return process
|
||||
|
||||
|
||||
class MySQLDialect_cymysql(MySQLDialect_mysqldb):
|
||||
driver = "cymysql"
|
||||
supports_statement_cache = True
|
||||
|
||||
description_encoding = None
|
||||
supports_sane_rowcount = True
|
||||
supports_sane_multi_rowcount = False
|
||||
supports_unicode_statements = True
|
||||
|
||||
colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _cymysqlBIT})
|
||||
|
||||
@classmethod
|
||||
def import_dbapi(cls):
|
||||
return __import__("cymysql")
|
||||
|
||||
def _detect_charset(self, connection):
|
||||
return connection.connection.charset
|
||||
|
||||
def _extract_error_code(self, exception):
|
||||
return exception.errno
|
||||
|
||||
def is_disconnect(self, e, connection, cursor):
|
||||
if isinstance(e, self.dbapi.OperationalError):
|
||||
return self._extract_error_code(e) in (
|
||||
2006,
|
||||
2013,
|
||||
2014,
|
||||
2045,
|
||||
2055,
|
||||
)
|
||||
elif isinstance(e, self.dbapi.InterfaceError):
|
||||
# if underlying connection is closed,
|
||||
# this is the error you get
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
dialect = MySQLDialect_cymysql
|
|
@ -0,0 +1,219 @@
|
|||
# dialects/mysql/dml.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import List
|
||||
from typing import Mapping
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
from ... import exc
|
||||
from ... import util
|
||||
from ...sql._typing import _DMLTableArgument
|
||||
from ...sql.base import _exclusive_against
|
||||
from ...sql.base import _generative
|
||||
from ...sql.base import ColumnCollection
|
||||
from ...sql.base import ReadOnlyColumnCollection
|
||||
from ...sql.dml import Insert as StandardInsert
|
||||
from ...sql.elements import ClauseElement
|
||||
from ...sql.elements import KeyedColumnElement
|
||||
from ...sql.expression import alias
|
||||
from ...sql.selectable import NamedFromClause
|
||||
from ...util.typing import Self
|
||||
|
||||
|
||||
__all__ = ("Insert", "insert")
|
||||
|
||||
|
||||
def insert(table: _DMLTableArgument) -> Insert:
|
||||
"""Construct a MySQL/MariaDB-specific variant :class:`_mysql.Insert`
|
||||
construct.
|
||||
|
||||
.. container:: inherited_member
|
||||
|
||||
The :func:`sqlalchemy.dialects.mysql.insert` function creates
|
||||
a :class:`sqlalchemy.dialects.mysql.Insert`. This class is based
|
||||
on the dialect-agnostic :class:`_sql.Insert` construct which may
|
||||
be constructed using the :func:`_sql.insert` function in
|
||||
SQLAlchemy Core.
|
||||
|
||||
The :class:`_mysql.Insert` construct includes additional methods
|
||||
:meth:`_mysql.Insert.on_duplicate_key_update`.
|
||||
|
||||
"""
|
||||
return Insert(table)
|
||||
|
||||
|
||||
class Insert(StandardInsert):
|
||||
"""MySQL-specific implementation of INSERT.
|
||||
|
||||
Adds methods for MySQL-specific syntaxes such as ON DUPLICATE KEY UPDATE.
|
||||
|
||||
The :class:`~.mysql.Insert` object is created using the
|
||||
:func:`sqlalchemy.dialects.mysql.insert` function.
|
||||
|
||||
.. versionadded:: 1.2
|
||||
|
||||
"""
|
||||
|
||||
stringify_dialect = "mysql"
|
||||
inherit_cache = False
|
||||
|
||||
@property
|
||||
def inserted(
|
||||
self,
|
||||
) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]:
|
||||
"""Provide the "inserted" namespace for an ON DUPLICATE KEY UPDATE
|
||||
statement
|
||||
|
||||
MySQL's ON DUPLICATE KEY UPDATE clause allows reference to the row
|
||||
that would be inserted, via a special function called ``VALUES()``.
|
||||
This attribute provides all columns in this row to be referenceable
|
||||
such that they will render within a ``VALUES()`` function inside the
|
||||
ON DUPLICATE KEY UPDATE clause. The attribute is named ``.inserted``
|
||||
so as not to conflict with the existing
|
||||
:meth:`_expression.Insert.values` method.
|
||||
|
||||
.. tip:: The :attr:`_mysql.Insert.inserted` attribute is an instance
|
||||
of :class:`_expression.ColumnCollection`, which provides an
|
||||
interface the same as that of the :attr:`_schema.Table.c`
|
||||
collection described at :ref:`metadata_tables_and_columns`.
|
||||
With this collection, ordinary names are accessible like attributes
|
||||
(e.g. ``stmt.inserted.some_column``), but special names and
|
||||
dictionary method names should be accessed using indexed access,
|
||||
such as ``stmt.inserted["column name"]`` or
|
||||
``stmt.inserted["values"]``. See the docstring for
|
||||
:class:`_expression.ColumnCollection` for further examples.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:ref:`mysql_insert_on_duplicate_key_update` - example of how
|
||||
to use :attr:`_expression.Insert.inserted`
|
||||
|
||||
"""
|
||||
return self.inserted_alias.columns
|
||||
|
||||
@util.memoized_property
|
||||
def inserted_alias(self) -> NamedFromClause:
|
||||
return alias(self.table, name="inserted")
|
||||
|
||||
@_generative
|
||||
@_exclusive_against(
|
||||
"_post_values_clause",
|
||||
msgs={
|
||||
"_post_values_clause": "This Insert construct already "
|
||||
"has an ON DUPLICATE KEY clause present"
|
||||
},
|
||||
)
|
||||
def on_duplicate_key_update(self, *args: _UpdateArg, **kw: Any) -> Self:
|
||||
r"""
|
||||
Specifies the ON DUPLICATE KEY UPDATE clause.
|
||||
|
||||
:param \**kw: Column keys linked to UPDATE values. The
|
||||
values may be any SQL expression or supported literal Python
|
||||
values.
|
||||
|
||||
.. warning:: This dictionary does **not** take into account
|
||||
Python-specified default UPDATE values or generation functions,
|
||||
e.g. those specified using :paramref:`_schema.Column.onupdate`.
|
||||
These values will not be exercised for an ON DUPLICATE KEY UPDATE
|
||||
style of UPDATE, unless values are manually specified here.
|
||||
|
||||
:param \*args: As an alternative to passing key/value parameters,
|
||||
a dictionary or list of 2-tuples can be passed as a single positional
|
||||
argument.
|
||||
|
||||
Passing a single dictionary is equivalent to the keyword argument
|
||||
form::
|
||||
|
||||
insert().on_duplicate_key_update({"name": "some name"})
|
||||
|
||||
Passing a list of 2-tuples indicates that the parameter assignments
|
||||
in the UPDATE clause should be ordered as sent, in a manner similar
|
||||
to that described for the :class:`_expression.Update`
|
||||
construct overall
|
||||
in :ref:`tutorial_parameter_ordered_updates`::
|
||||
|
||||
insert().on_duplicate_key_update(
|
||||
[("name", "some name"), ("value", "some value")])
|
||||
|
||||
.. versionchanged:: 1.3 parameters can be specified as a dictionary
|
||||
or list of 2-tuples; the latter form provides for parameter
|
||||
ordering.
|
||||
|
||||
|
||||
.. versionadded:: 1.2
|
||||
|
||||
.. seealso::
|
||||
|
||||
:ref:`mysql_insert_on_duplicate_key_update`
|
||||
|
||||
"""
|
||||
if args and kw:
|
||||
raise exc.ArgumentError(
|
||||
"Can't pass kwargs and positional arguments simultaneously"
|
||||
)
|
||||
|
||||
if args:
|
||||
if len(args) > 1:
|
||||
raise exc.ArgumentError(
|
||||
"Only a single dictionary or list of tuples "
|
||||
"is accepted positionally."
|
||||
)
|
||||
values = args[0]
|
||||
else:
|
||||
values = kw
|
||||
|
||||
self._post_values_clause = OnDuplicateClause(
|
||||
self.inserted_alias, values
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class OnDuplicateClause(ClauseElement):
|
||||
__visit_name__ = "on_duplicate_key_update"
|
||||
|
||||
_parameter_ordering: Optional[List[str]] = None
|
||||
|
||||
stringify_dialect = "mysql"
|
||||
|
||||
def __init__(
|
||||
self, inserted_alias: NamedFromClause, update: _UpdateArg
|
||||
) -> None:
|
||||
self.inserted_alias = inserted_alias
|
||||
|
||||
# auto-detect that parameters should be ordered. This is copied from
|
||||
# Update._proces_colparams(), however we don't look for a special flag
|
||||
# in this case since we are not disambiguating from other use cases as
|
||||
# we are in Update.values().
|
||||
if isinstance(update, list) and (
|
||||
update and isinstance(update[0], tuple)
|
||||
):
|
||||
self._parameter_ordering = [key for key, value in update]
|
||||
update = dict(update)
|
||||
|
||||
if isinstance(update, dict):
|
||||
if not update:
|
||||
raise ValueError(
|
||||
"update parameter dictionary must not be empty"
|
||||
)
|
||||
elif isinstance(update, ColumnCollection):
|
||||
update = dict(update)
|
||||
else:
|
||||
raise ValueError(
|
||||
"update parameter must be a non-empty dictionary "
|
||||
"or a ColumnCollection such as the `.c.` collection "
|
||||
"of a Table object"
|
||||
)
|
||||
self.update = update
|
||||
|
||||
|
||||
_UpdateArg = Union[
|
||||
Mapping[Any, Any], List[Tuple[str, Any]], ColumnCollection[Any, Any]
|
||||
]
|
|
@ -0,0 +1,244 @@
|
|||
# dialects/mysql/enumerated.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
import re
|
||||
|
||||
from .types import _StringType
|
||||
from ... import exc
|
||||
from ... import sql
|
||||
from ... import util
|
||||
from ...sql import sqltypes
|
||||
|
||||
|
||||
class ENUM(sqltypes.NativeForEmulated, sqltypes.Enum, _StringType):
|
||||
"""MySQL ENUM type."""
|
||||
|
||||
__visit_name__ = "ENUM"
|
||||
|
||||
native_enum = True
|
||||
|
||||
def __init__(self, *enums, **kw):
|
||||
"""Construct an ENUM.
|
||||
|
||||
E.g.::
|
||||
|
||||
Column('myenum', ENUM("foo", "bar", "baz"))
|
||||
|
||||
:param enums: The range of valid values for this ENUM. Values in
|
||||
enums are not quoted, they will be escaped and surrounded by single
|
||||
quotes when generating the schema. This object may also be a
|
||||
PEP-435-compliant enumerated type.
|
||||
|
||||
.. versionadded: 1.1 added support for PEP-435-compliant enumerated
|
||||
types.
|
||||
|
||||
:param strict: This flag has no effect.
|
||||
|
||||
.. versionchanged:: The MySQL ENUM type as well as the base Enum
|
||||
type now validates all Python data values.
|
||||
|
||||
:param charset: Optional, a column-level character set for this string
|
||||
value. Takes precedence to 'ascii' or 'unicode' short-hand.
|
||||
|
||||
:param collation: Optional, a column-level collation for this string
|
||||
value. Takes precedence to 'binary' short-hand.
|
||||
|
||||
:param ascii: Defaults to False: short-hand for the ``latin1``
|
||||
character set, generates ASCII in schema.
|
||||
|
||||
:param unicode: Defaults to False: short-hand for the ``ucs2``
|
||||
character set, generates UNICODE in schema.
|
||||
|
||||
:param binary: Defaults to False: short-hand, pick the binary
|
||||
collation type that matches the column's character set. Generates
|
||||
BINARY in schema. This does not affect the type of data stored,
|
||||
only the collation of character data.
|
||||
|
||||
"""
|
||||
kw.pop("strict", None)
|
||||
self._enum_init(enums, kw)
|
||||
_StringType.__init__(self, length=self.length, **kw)
|
||||
|
||||
@classmethod
|
||||
def adapt_emulated_to_native(cls, impl, **kw):
|
||||
"""Produce a MySQL native :class:`.mysql.ENUM` from plain
|
||||
:class:`.Enum`.
|
||||
|
||||
"""
|
||||
kw.setdefault("validate_strings", impl.validate_strings)
|
||||
kw.setdefault("values_callable", impl.values_callable)
|
||||
kw.setdefault("omit_aliases", impl._omit_aliases)
|
||||
return cls(**kw)
|
||||
|
||||
def _object_value_for_elem(self, elem):
|
||||
# mysql sends back a blank string for any value that
|
||||
# was persisted that was not in the enums; that is, it does no
|
||||
# validation on the incoming data, it "truncates" it to be
|
||||
# the blank string. Return it straight.
|
||||
if elem == "":
|
||||
return elem
|
||||
else:
|
||||
return super()._object_value_for_elem(elem)
|
||||
|
||||
def __repr__(self):
|
||||
return util.generic_repr(
|
||||
self, to_inspect=[ENUM, _StringType, sqltypes.Enum]
|
||||
)
|
||||
|
||||
|
||||
class SET(_StringType):
|
||||
"""MySQL SET type."""
|
||||
|
||||
__visit_name__ = "SET"
|
||||
|
||||
def __init__(self, *values, **kw):
|
||||
"""Construct a SET.
|
||||
|
||||
E.g.::
|
||||
|
||||
Column('myset', SET("foo", "bar", "baz"))
|
||||
|
||||
|
||||
The list of potential values is required in the case that this
|
||||
set will be used to generate DDL for a table, or if the
|
||||
:paramref:`.SET.retrieve_as_bitwise` flag is set to True.
|
||||
|
||||
:param values: The range of valid values for this SET. The values
|
||||
are not quoted, they will be escaped and surrounded by single
|
||||
quotes when generating the schema.
|
||||
|
||||
:param convert_unicode: Same flag as that of
|
||||
:paramref:`.String.convert_unicode`.
|
||||
|
||||
:param collation: same as that of :paramref:`.String.collation`
|
||||
|
||||
:param charset: same as that of :paramref:`.VARCHAR.charset`.
|
||||
|
||||
:param ascii: same as that of :paramref:`.VARCHAR.ascii`.
|
||||
|
||||
:param unicode: same as that of :paramref:`.VARCHAR.unicode`.
|
||||
|
||||
:param binary: same as that of :paramref:`.VARCHAR.binary`.
|
||||
|
||||
:param retrieve_as_bitwise: if True, the data for the set type will be
|
||||
persisted and selected using an integer value, where a set is coerced
|
||||
into a bitwise mask for persistence. MySQL allows this mode which
|
||||
has the advantage of being able to store values unambiguously,
|
||||
such as the blank string ``''``. The datatype will appear
|
||||
as the expression ``col + 0`` in a SELECT statement, so that the
|
||||
value is coerced into an integer value in result sets.
|
||||
This flag is required if one wishes
|
||||
to persist a set that can store the blank string ``''`` as a value.
|
||||
|
||||
.. warning::
|
||||
|
||||
When using :paramref:`.mysql.SET.retrieve_as_bitwise`, it is
|
||||
essential that the list of set values is expressed in the
|
||||
**exact same order** as exists on the MySQL database.
|
||||
|
||||
"""
|
||||
self.retrieve_as_bitwise = kw.pop("retrieve_as_bitwise", False)
|
||||
self.values = tuple(values)
|
||||
if not self.retrieve_as_bitwise and "" in values:
|
||||
raise exc.ArgumentError(
|
||||
"Can't use the blank value '' in a SET without "
|
||||
"setting retrieve_as_bitwise=True"
|
||||
)
|
||||
if self.retrieve_as_bitwise:
|
||||
self._bitmap = {
|
||||
value: 2**idx for idx, value in enumerate(self.values)
|
||||
}
|
||||
self._bitmap.update(
|
||||
(2**idx, value) for idx, value in enumerate(self.values)
|
||||
)
|
||||
length = max([len(v) for v in values] + [0])
|
||||
kw.setdefault("length", length)
|
||||
super().__init__(**kw)
|
||||
|
||||
def column_expression(self, colexpr):
|
||||
if self.retrieve_as_bitwise:
|
||||
return sql.type_coerce(
|
||||
sql.type_coerce(colexpr, sqltypes.Integer) + 0, self
|
||||
)
|
||||
else:
|
||||
return colexpr
|
||||
|
||||
def result_processor(self, dialect, coltype):
|
||||
if self.retrieve_as_bitwise:
|
||||
|
||||
def process(value):
|
||||
if value is not None:
|
||||
value = int(value)
|
||||
|
||||
return set(util.map_bits(self._bitmap.__getitem__, value))
|
||||
else:
|
||||
return None
|
||||
|
||||
else:
|
||||
super_convert = super().result_processor(dialect, coltype)
|
||||
|
||||
def process(value):
|
||||
if isinstance(value, str):
|
||||
# MySQLdb returns a string, let's parse
|
||||
if super_convert:
|
||||
value = super_convert(value)
|
||||
return set(re.findall(r"[^,]+", value))
|
||||
else:
|
||||
# mysql-connector-python does a naive
|
||||
# split(",") which throws in an empty string
|
||||
if value is not None:
|
||||
value.discard("")
|
||||
return value
|
||||
|
||||
return process
|
||||
|
||||
def bind_processor(self, dialect):
|
||||
super_convert = super().bind_processor(dialect)
|
||||
if self.retrieve_as_bitwise:
|
||||
|
||||
def process(value):
|
||||
if value is None:
|
||||
return None
|
||||
elif isinstance(value, (int, str)):
|
||||
if super_convert:
|
||||
return super_convert(value)
|
||||
else:
|
||||
return value
|
||||
else:
|
||||
int_value = 0
|
||||
for v in value:
|
||||
int_value |= self._bitmap[v]
|
||||
return int_value
|
||||
|
||||
else:
|
||||
|
||||
def process(value):
|
||||
# accept strings and int (actually bitflag) values directly
|
||||
if value is not None and not isinstance(value, (int, str)):
|
||||
value = ",".join(value)
|
||||
|
||||
if super_convert:
|
||||
return super_convert(value)
|
||||
else:
|
||||
return value
|
||||
|
||||
return process
|
||||
|
||||
def adapt(self, impltype, **kw):
|
||||
kw["retrieve_as_bitwise"] = self.retrieve_as_bitwise
|
||||
return util.constructor_copy(self, impltype, *self.values, **kw)
|
||||
|
||||
def __repr__(self):
|
||||
return util.generic_repr(
|
||||
self,
|
||||
to_inspect=[SET, _StringType],
|
||||
additional_kw=[
|
||||
("retrieve_as_bitwise", False),
|
||||
],
|
||||
)
|
|
@ -0,0 +1,141 @@
|
|||
# dialects/mysql/expression.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
from ... import exc
|
||||
from ... import util
|
||||
from ...sql import coercions
|
||||
from ...sql import elements
|
||||
from ...sql import operators
|
||||
from ...sql import roles
|
||||
from ...sql.base import _generative
|
||||
from ...sql.base import Generative
|
||||
from ...util.typing import Self
|
||||
|
||||
|
||||
class match(Generative, elements.BinaryExpression):
|
||||
"""Produce a ``MATCH (X, Y) AGAINST ('TEXT')`` clause.
|
||||
|
||||
E.g.::
|
||||
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy.dialects.mysql import match
|
||||
|
||||
match_expr = match(
|
||||
users_table.c.firstname,
|
||||
users_table.c.lastname,
|
||||
against="Firstname Lastname",
|
||||
)
|
||||
|
||||
stmt = (
|
||||
select(users_table)
|
||||
.where(match_expr.in_boolean_mode())
|
||||
.order_by(desc(match_expr))
|
||||
)
|
||||
|
||||
Would produce SQL resembling::
|
||||
|
||||
SELECT id, firstname, lastname
|
||||
FROM user
|
||||
WHERE MATCH(firstname, lastname) AGAINST (:param_1 IN BOOLEAN MODE)
|
||||
ORDER BY MATCH(firstname, lastname) AGAINST (:param_2) DESC
|
||||
|
||||
The :func:`_mysql.match` function is a standalone version of the
|
||||
:meth:`_sql.ColumnElement.match` method available on all
|
||||
SQL expressions, as when :meth:`_expression.ColumnElement.match` is
|
||||
used, but allows to pass multiple columns
|
||||
|
||||
:param cols: column expressions to match against
|
||||
|
||||
:param against: expression to be compared towards
|
||||
|
||||
:param in_boolean_mode: boolean, set "boolean mode" to true
|
||||
|
||||
:param in_natural_language_mode: boolean , set "natural language" to true
|
||||
|
||||
:param with_query_expansion: boolean, set "query expansion" to true
|
||||
|
||||
.. versionadded:: 1.4.19
|
||||
|
||||
.. seealso::
|
||||
|
||||
:meth:`_expression.ColumnElement.match`
|
||||
|
||||
"""
|
||||
|
||||
__visit_name__ = "mysql_match"
|
||||
|
||||
inherit_cache = True
|
||||
|
||||
def __init__(self, *cols, **kw):
|
||||
if not cols:
|
||||
raise exc.ArgumentError("columns are required")
|
||||
|
||||
against = kw.pop("against", None)
|
||||
|
||||
if against is None:
|
||||
raise exc.ArgumentError("against is required")
|
||||
against = coercions.expect(
|
||||
roles.ExpressionElementRole,
|
||||
against,
|
||||
)
|
||||
|
||||
left = elements.BooleanClauseList._construct_raw(
|
||||
operators.comma_op,
|
||||
clauses=cols,
|
||||
)
|
||||
left.group = False
|
||||
|
||||
flags = util.immutabledict(
|
||||
{
|
||||
"mysql_boolean_mode": kw.pop("in_boolean_mode", False),
|
||||
"mysql_natural_language": kw.pop(
|
||||
"in_natural_language_mode", False
|
||||
),
|
||||
"mysql_query_expansion": kw.pop("with_query_expansion", False),
|
||||
}
|
||||
)
|
||||
|
||||
if kw:
|
||||
raise exc.ArgumentError("unknown arguments: %s" % (", ".join(kw)))
|
||||
|
||||
super().__init__(left, against, operators.match_op, modifiers=flags)
|
||||
|
||||
@_generative
|
||||
def in_boolean_mode(self) -> Self:
|
||||
"""Apply the "IN BOOLEAN MODE" modifier to the MATCH expression.
|
||||
|
||||
:return: a new :class:`_mysql.match` instance with modifications
|
||||
applied.
|
||||
"""
|
||||
|
||||
self.modifiers = self.modifiers.union({"mysql_boolean_mode": True})
|
||||
return self
|
||||
|
||||
@_generative
|
||||
def in_natural_language_mode(self) -> Self:
|
||||
"""Apply the "IN NATURAL LANGUAGE MODE" modifier to the MATCH
|
||||
expression.
|
||||
|
||||
:return: a new :class:`_mysql.match` instance with modifications
|
||||
applied.
|
||||
"""
|
||||
|
||||
self.modifiers = self.modifiers.union({"mysql_natural_language": True})
|
||||
return self
|
||||
|
||||
@_generative
|
||||
def with_query_expansion(self) -> Self:
|
||||
"""Apply the "WITH QUERY EXPANSION" modifier to the MATCH expression.
|
||||
|
||||
:return: a new :class:`_mysql.match` instance with modifications
|
||||
applied.
|
||||
"""
|
||||
|
||||
self.modifiers = self.modifiers.union({"mysql_query_expansion": True})
|
||||
return self
|
|
@ -0,0 +1,81 @@
|
|||
# dialects/mysql/json.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
from ... import types as sqltypes
|
||||
|
||||
|
||||
class JSON(sqltypes.JSON):
|
||||
"""MySQL JSON type.
|
||||
|
||||
MySQL supports JSON as of version 5.7.
|
||||
MariaDB supports JSON (as an alias for LONGTEXT) as of version 10.2.
|
||||
|
||||
:class:`_mysql.JSON` is used automatically whenever the base
|
||||
:class:`_types.JSON` datatype is used against a MySQL or MariaDB backend.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:class:`_types.JSON` - main documentation for the generic
|
||||
cross-platform JSON datatype.
|
||||
|
||||
The :class:`.mysql.JSON` type supports persistence of JSON values
|
||||
as well as the core index operations provided by :class:`_types.JSON`
|
||||
datatype, by adapting the operations to render the ``JSON_EXTRACT``
|
||||
function at the database level.
|
||||
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class _FormatTypeMixin:
|
||||
def _format_value(self, value):
|
||||
raise NotImplementedError()
|
||||
|
||||
def bind_processor(self, dialect):
|
||||
super_proc = self.string_bind_processor(dialect)
|
||||
|
||||
def process(value):
|
||||
value = self._format_value(value)
|
||||
if super_proc:
|
||||
value = super_proc(value)
|
||||
return value
|
||||
|
||||
return process
|
||||
|
||||
def literal_processor(self, dialect):
|
||||
super_proc = self.string_literal_processor(dialect)
|
||||
|
||||
def process(value):
|
||||
value = self._format_value(value)
|
||||
if super_proc:
|
||||
value = super_proc(value)
|
||||
return value
|
||||
|
||||
return process
|
||||
|
||||
|
||||
class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType):
|
||||
def _format_value(self, value):
|
||||
if isinstance(value, int):
|
||||
value = "$[%s]" % value
|
||||
else:
|
||||
value = '$."%s"' % value
|
||||
return value
|
||||
|
||||
|
||||
class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType):
|
||||
def _format_value(self, value):
|
||||
return "$%s" % (
|
||||
"".join(
|
||||
[
|
||||
"[%s]" % elem if isinstance(elem, int) else '."%s"' % elem
|
||||
for elem in value
|
||||
]
|
||||
)
|
||||
)
|
|
@ -0,0 +1,32 @@
|
|||
# dialects/mysql/mariadb.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
from .base import MariaDBIdentifierPreparer
|
||||
from .base import MySQLDialect
|
||||
|
||||
|
||||
class MariaDBDialect(MySQLDialect):
|
||||
is_mariadb = True
|
||||
supports_statement_cache = True
|
||||
name = "mariadb"
|
||||
preparer = MariaDBIdentifierPreparer
|
||||
|
||||
|
||||
def loader(driver):
|
||||
driver_mod = __import__(
|
||||
"sqlalchemy.dialects.mysql.%s" % driver
|
||||
).dialects.mysql
|
||||
driver_cls = getattr(driver_mod, driver).dialect
|
||||
|
||||
return type(
|
||||
"MariaDBDialect_%s" % driver,
|
||||
(
|
||||
MariaDBDialect,
|
||||
driver_cls,
|
||||
),
|
||||
{"supports_statement_cache": True},
|
||||
)
|
|
@ -0,0 +1,282 @@
|
|||
# dialects/mysql/mariadbconnector.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
"""
|
||||
|
||||
.. dialect:: mysql+mariadbconnector
|
||||
:name: MariaDB Connector/Python
|
||||
:dbapi: mariadb
|
||||
:connectstring: mariadb+mariadbconnector://<user>:<password>@<host>[:<port>]/<dbname>
|
||||
:url: https://pypi.org/project/mariadb/
|
||||
|
||||
Driver Status
|
||||
-------------
|
||||
|
||||
MariaDB Connector/Python enables Python programs to access MariaDB and MySQL
|
||||
databases using an API which is compliant with the Python DB API 2.0 (PEP-249).
|
||||
It is written in C and uses MariaDB Connector/C client library for client server
|
||||
communication.
|
||||
|
||||
Note that the default driver for a ``mariadb://`` connection URI continues to
|
||||
be ``mysqldb``. ``mariadb+mariadbconnector://`` is required to use this driver.
|
||||
|
||||
.. mariadb: https://github.com/mariadb-corporation/mariadb-connector-python
|
||||
|
||||
""" # noqa
|
||||
import re
|
||||
from uuid import UUID as _python_UUID
|
||||
|
||||
from .base import MySQLCompiler
|
||||
from .base import MySQLDialect
|
||||
from .base import MySQLExecutionContext
|
||||
from ... import sql
|
||||
from ... import util
|
||||
from ...sql import sqltypes
|
||||
|
||||
|
||||
mariadb_cpy_minimum_version = (1, 0, 1)
|
||||
|
||||
|
||||
class _MariaDBUUID(sqltypes.UUID[sqltypes._UUID_RETURN]):
|
||||
# work around JIRA issue
|
||||
# https://jira.mariadb.org/browse/CONPY-270. When that issue is fixed,
|
||||
# this type can be removed.
|
||||
def result_processor(self, dialect, coltype):
|
||||
if self.as_uuid:
|
||||
|
||||
def process(value):
|
||||
if value is not None:
|
||||
if hasattr(value, "decode"):
|
||||
value = value.decode("ascii")
|
||||
value = _python_UUID(value)
|
||||
return value
|
||||
|
||||
return process
|
||||
else:
|
||||
|
||||
def process(value):
|
||||
if value is not None:
|
||||
if hasattr(value, "decode"):
|
||||
value = value.decode("ascii")
|
||||
value = str(_python_UUID(value))
|
||||
return value
|
||||
|
||||
return process
|
||||
|
||||
|
||||
class MySQLExecutionContext_mariadbconnector(MySQLExecutionContext):
|
||||
_lastrowid = None
|
||||
|
||||
def create_server_side_cursor(self):
|
||||
return self._dbapi_connection.cursor(buffered=False)
|
||||
|
||||
def create_default_cursor(self):
|
||||
return self._dbapi_connection.cursor(buffered=True)
|
||||
|
||||
def post_exec(self):
|
||||
super().post_exec()
|
||||
|
||||
self._rowcount = self.cursor.rowcount
|
||||
|
||||
if self.isinsert and self.compiled.postfetch_lastrowid:
|
||||
self._lastrowid = self.cursor.lastrowid
|
||||
|
||||
@property
|
||||
def rowcount(self):
|
||||
if self._rowcount is not None:
|
||||
return self._rowcount
|
||||
else:
|
||||
return self.cursor.rowcount
|
||||
|
||||
def get_lastrowid(self):
|
||||
return self._lastrowid
|
||||
|
||||
|
||||
class MySQLCompiler_mariadbconnector(MySQLCompiler):
|
||||
pass
|
||||
|
||||
|
||||
class MySQLDialect_mariadbconnector(MySQLDialect):
|
||||
driver = "mariadbconnector"
|
||||
supports_statement_cache = True
|
||||
|
||||
# set this to True at the module level to prevent the driver from running
|
||||
# against a backend that server detects as MySQL. currently this appears to
|
||||
# be unnecessary as MariaDB client libraries have always worked against
|
||||
# MySQL databases. However, if this changes at some point, this can be
|
||||
# adjusted, but PLEASE ADD A TEST in test/dialect/mysql/test_dialect.py if
|
||||
# this change is made at some point to ensure the correct exception
|
||||
# is raised at the correct point when running the driver against
|
||||
# a MySQL backend.
|
||||
# is_mariadb = True
|
||||
|
||||
supports_unicode_statements = True
|
||||
encoding = "utf8mb4"
|
||||
convert_unicode = True
|
||||
supports_sane_rowcount = True
|
||||
supports_sane_multi_rowcount = True
|
||||
supports_native_decimal = True
|
||||
default_paramstyle = "qmark"
|
||||
execution_ctx_cls = MySQLExecutionContext_mariadbconnector
|
||||
statement_compiler = MySQLCompiler_mariadbconnector
|
||||
|
||||
supports_server_side_cursors = True
|
||||
|
||||
colspecs = util.update_copy(
|
||||
MySQLDialect.colspecs, {sqltypes.Uuid: _MariaDBUUID}
|
||||
)
|
||||
|
||||
@util.memoized_property
|
||||
def _dbapi_version(self):
|
||||
if self.dbapi and hasattr(self.dbapi, "__version__"):
|
||||
return tuple(
|
||||
[
|
||||
int(x)
|
||||
for x in re.findall(
|
||||
r"(\d+)(?:[-\.]?|$)", self.dbapi.__version__
|
||||
)
|
||||
]
|
||||
)
|
||||
else:
|
||||
return (99, 99, 99)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.paramstyle = "qmark"
|
||||
if self.dbapi is not None:
|
||||
if self._dbapi_version < mariadb_cpy_minimum_version:
|
||||
raise NotImplementedError(
|
||||
"The minimum required version for MariaDB "
|
||||
"Connector/Python is %s"
|
||||
% ".".join(str(x) for x in mariadb_cpy_minimum_version)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def import_dbapi(cls):
|
||||
return __import__("mariadb")
|
||||
|
||||
def is_disconnect(self, e, connection, cursor):
|
||||
if super().is_disconnect(e, connection, cursor):
|
||||
return True
|
||||
elif isinstance(e, self.dbapi.Error):
|
||||
str_e = str(e).lower()
|
||||
return "not connected" in str_e or "isn't valid" in str_e
|
||||
else:
|
||||
return False
|
||||
|
||||
def create_connect_args(self, url):
|
||||
opts = url.translate_connect_args()
|
||||
|
||||
int_params = [
|
||||
"connect_timeout",
|
||||
"read_timeout",
|
||||
"write_timeout",
|
||||
"client_flag",
|
||||
"port",
|
||||
"pool_size",
|
||||
]
|
||||
bool_params = [
|
||||
"local_infile",
|
||||
"ssl_verify_cert",
|
||||
"ssl",
|
||||
"pool_reset_connection",
|
||||
]
|
||||
|
||||
for key in int_params:
|
||||
util.coerce_kw_type(opts, key, int)
|
||||
for key in bool_params:
|
||||
util.coerce_kw_type(opts, key, bool)
|
||||
|
||||
# FOUND_ROWS must be set in CLIENT_FLAGS to enable
|
||||
# supports_sane_rowcount.
|
||||
client_flag = opts.get("client_flag", 0)
|
||||
if self.dbapi is not None:
|
||||
try:
|
||||
CLIENT_FLAGS = __import__(
|
||||
self.dbapi.__name__ + ".constants.CLIENT"
|
||||
).constants.CLIENT
|
||||
client_flag |= CLIENT_FLAGS.FOUND_ROWS
|
||||
except (AttributeError, ImportError):
|
||||
self.supports_sane_rowcount = False
|
||||
opts["client_flag"] = client_flag
|
||||
return [[], opts]
|
||||
|
||||
def _extract_error_code(self, exception):
|
||||
try:
|
||||
rc = exception.errno
|
||||
except:
|
||||
rc = -1
|
||||
return rc
|
||||
|
||||
def _detect_charset(self, connection):
|
||||
return "utf8mb4"
|
||||
|
||||
def get_isolation_level_values(self, dbapi_connection):
|
||||
return (
|
||||
"SERIALIZABLE",
|
||||
"READ UNCOMMITTED",
|
||||
"READ COMMITTED",
|
||||
"REPEATABLE READ",
|
||||
"AUTOCOMMIT",
|
||||
)
|
||||
|
||||
def set_isolation_level(self, connection, level):
|
||||
if level == "AUTOCOMMIT":
|
||||
connection.autocommit = True
|
||||
else:
|
||||
connection.autocommit = False
|
||||
super().set_isolation_level(connection, level)
|
||||
|
||||
def do_begin_twophase(self, connection, xid):
|
||||
connection.execute(
|
||||
sql.text("XA BEGIN :xid").bindparams(
|
||||
sql.bindparam("xid", xid, literal_execute=True)
|
||||
)
|
||||
)
|
||||
|
||||
def do_prepare_twophase(self, connection, xid):
|
||||
connection.execute(
|
||||
sql.text("XA END :xid").bindparams(
|
||||
sql.bindparam("xid", xid, literal_execute=True)
|
||||
)
|
||||
)
|
||||
connection.execute(
|
||||
sql.text("XA PREPARE :xid").bindparams(
|
||||
sql.bindparam("xid", xid, literal_execute=True)
|
||||
)
|
||||
)
|
||||
|
||||
def do_rollback_twophase(
|
||||
self, connection, xid, is_prepared=True, recover=False
|
||||
):
|
||||
if not is_prepared:
|
||||
connection.execute(
|
||||
sql.text("XA END :xid").bindparams(
|
||||
sql.bindparam("xid", xid, literal_execute=True)
|
||||
)
|
||||
)
|
||||
connection.execute(
|
||||
sql.text("XA ROLLBACK :xid").bindparams(
|
||||
sql.bindparam("xid", xid, literal_execute=True)
|
||||
)
|
||||
)
|
||||
|
||||
def do_commit_twophase(
|
||||
self, connection, xid, is_prepared=True, recover=False
|
||||
):
|
||||
if not is_prepared:
|
||||
self.do_prepare_twophase(connection, xid)
|
||||
connection.execute(
|
||||
sql.text("XA COMMIT :xid").bindparams(
|
||||
sql.bindparam("xid", xid, literal_execute=True)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
dialect = MySQLDialect_mariadbconnector
|
|
@ -0,0 +1,179 @@
|
|||
# dialects/mysql/mysqlconnector.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
r"""
|
||||
.. dialect:: mysql+mysqlconnector
|
||||
:name: MySQL Connector/Python
|
||||
:dbapi: myconnpy
|
||||
:connectstring: mysql+mysqlconnector://<user>:<password>@<host>[:<port>]/<dbname>
|
||||
:url: https://pypi.org/project/mysql-connector-python/
|
||||
|
||||
.. note::
|
||||
|
||||
The MySQL Connector/Python DBAPI has had many issues since its release,
|
||||
some of which may remain unresolved, and the mysqlconnector dialect is
|
||||
**not tested as part of SQLAlchemy's continuous integration**.
|
||||
The recommended MySQL dialects are mysqlclient and PyMySQL.
|
||||
|
||||
""" # noqa
|
||||
|
||||
import re
|
||||
|
||||
from .base import BIT
|
||||
from .base import MySQLCompiler
|
||||
from .base import MySQLDialect
|
||||
from .base import MySQLIdentifierPreparer
|
||||
from ... import util
|
||||
|
||||
|
||||
class MySQLCompiler_mysqlconnector(MySQLCompiler):
|
||||
def visit_mod_binary(self, binary, operator, **kw):
|
||||
return (
|
||||
self.process(binary.left, **kw)
|
||||
+ " % "
|
||||
+ self.process(binary.right, **kw)
|
||||
)
|
||||
|
||||
|
||||
class MySQLIdentifierPreparer_mysqlconnector(MySQLIdentifierPreparer):
|
||||
@property
|
||||
def _double_percents(self):
|
||||
return False
|
||||
|
||||
@_double_percents.setter
|
||||
def _double_percents(self, value):
|
||||
pass
|
||||
|
||||
def _escape_identifier(self, value):
|
||||
value = value.replace(self.escape_quote, self.escape_to_quote)
|
||||
return value
|
||||
|
||||
|
||||
class _myconnpyBIT(BIT):
|
||||
def result_processor(self, dialect, coltype):
|
||||
"""MySQL-connector already converts mysql bits, so."""
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class MySQLDialect_mysqlconnector(MySQLDialect):
|
||||
driver = "mysqlconnector"
|
||||
supports_statement_cache = True
|
||||
|
||||
supports_sane_rowcount = True
|
||||
supports_sane_multi_rowcount = True
|
||||
|
||||
supports_native_decimal = True
|
||||
|
||||
default_paramstyle = "format"
|
||||
statement_compiler = MySQLCompiler_mysqlconnector
|
||||
|
||||
preparer = MySQLIdentifierPreparer_mysqlconnector
|
||||
|
||||
colspecs = util.update_copy(MySQLDialect.colspecs, {BIT: _myconnpyBIT})
|
||||
|
||||
@classmethod
|
||||
def import_dbapi(cls):
|
||||
from mysql import connector
|
||||
|
||||
return connector
|
||||
|
||||
def do_ping(self, dbapi_connection):
|
||||
dbapi_connection.ping(False)
|
||||
return True
|
||||
|
||||
def create_connect_args(self, url):
|
||||
opts = url.translate_connect_args(username="user")
|
||||
|
||||
opts.update(url.query)
|
||||
|
||||
util.coerce_kw_type(opts, "allow_local_infile", bool)
|
||||
util.coerce_kw_type(opts, "autocommit", bool)
|
||||
util.coerce_kw_type(opts, "buffered", bool)
|
||||
util.coerce_kw_type(opts, "compress", bool)
|
||||
util.coerce_kw_type(opts, "connection_timeout", int)
|
||||
util.coerce_kw_type(opts, "connect_timeout", int)
|
||||
util.coerce_kw_type(opts, "consume_results", bool)
|
||||
util.coerce_kw_type(opts, "force_ipv6", bool)
|
||||
util.coerce_kw_type(opts, "get_warnings", bool)
|
||||
util.coerce_kw_type(opts, "pool_reset_session", bool)
|
||||
util.coerce_kw_type(opts, "pool_size", int)
|
||||
util.coerce_kw_type(opts, "raise_on_warnings", bool)
|
||||
util.coerce_kw_type(opts, "raw", bool)
|
||||
util.coerce_kw_type(opts, "ssl_verify_cert", bool)
|
||||
util.coerce_kw_type(opts, "use_pure", bool)
|
||||
util.coerce_kw_type(opts, "use_unicode", bool)
|
||||
|
||||
# unfortunately, MySQL/connector python refuses to release a
|
||||
# cursor without reading fully, so non-buffered isn't an option
|
||||
opts.setdefault("buffered", True)
|
||||
|
||||
# FOUND_ROWS must be set in ClientFlag to enable
|
||||
# supports_sane_rowcount.
|
||||
if self.dbapi is not None:
|
||||
try:
|
||||
from mysql.connector.constants import ClientFlag
|
||||
|
||||
client_flags = opts.get(
|
||||
"client_flags", ClientFlag.get_default()
|
||||
)
|
||||
client_flags |= ClientFlag.FOUND_ROWS
|
||||
opts["client_flags"] = client_flags
|
||||
except Exception:
|
||||
pass
|
||||
return [[], opts]
|
||||
|
||||
@util.memoized_property
|
||||
def _mysqlconnector_version_info(self):
|
||||
if self.dbapi and hasattr(self.dbapi, "__version__"):
|
||||
m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", self.dbapi.__version__)
|
||||
if m:
|
||||
return tuple(int(x) for x in m.group(1, 2, 3) if x is not None)
|
||||
|
||||
def _detect_charset(self, connection):
|
||||
return connection.connection.charset
|
||||
|
||||
def _extract_error_code(self, exception):
|
||||
return exception.errno
|
||||
|
||||
def is_disconnect(self, e, connection, cursor):
|
||||
errnos = (2006, 2013, 2014, 2045, 2055, 2048)
|
||||
exceptions = (self.dbapi.OperationalError, self.dbapi.InterfaceError)
|
||||
if isinstance(e, exceptions):
|
||||
return (
|
||||
e.errno in errnos
|
||||
or "MySQL Connection not available." in str(e)
|
||||
or "Connection to MySQL is not available" in str(e)
|
||||
)
|
||||
else:
|
||||
return False
|
||||
|
||||
def _compat_fetchall(self, rp, charset=None):
|
||||
return rp.fetchall()
|
||||
|
||||
def _compat_fetchone(self, rp, charset=None):
|
||||
return rp.fetchone()
|
||||
|
||||
_isolation_lookup = {
|
||||
"SERIALIZABLE",
|
||||
"READ UNCOMMITTED",
|
||||
"READ COMMITTED",
|
||||
"REPEATABLE READ",
|
||||
"AUTOCOMMIT",
|
||||
}
|
||||
|
||||
def _set_isolation_level(self, connection, level):
|
||||
if level == "AUTOCOMMIT":
|
||||
connection.autocommit = True
|
||||
else:
|
||||
connection.autocommit = False
|
||||
super()._set_isolation_level(connection, level)
|
||||
|
||||
|
||||
dialect = MySQLDialect_mysqlconnector
|
|
@ -0,0 +1,308 @@
|
|||
# dialects/mysql/mysqldb.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
"""
|
||||
|
||||
.. dialect:: mysql+mysqldb
|
||||
:name: mysqlclient (maintained fork of MySQL-Python)
|
||||
:dbapi: mysqldb
|
||||
:connectstring: mysql+mysqldb://<user>:<password>@<host>[:<port>]/<dbname>
|
||||
:url: https://pypi.org/project/mysqlclient/
|
||||
|
||||
Driver Status
|
||||
-------------
|
||||
|
||||
The mysqlclient DBAPI is a maintained fork of the
|
||||
`MySQL-Python <https://sourceforge.net/projects/mysql-python>`_ DBAPI
|
||||
that is no longer maintained. `mysqlclient`_ supports Python 2 and Python 3
|
||||
and is very stable.
|
||||
|
||||
.. _mysqlclient: https://github.com/PyMySQL/mysqlclient-python
|
||||
|
||||
.. _mysqldb_unicode:
|
||||
|
||||
Unicode
|
||||
-------
|
||||
|
||||
Please see :ref:`mysql_unicode` for current recommendations on unicode
|
||||
handling.
|
||||
|
||||
.. _mysqldb_ssl:
|
||||
|
||||
SSL Connections
|
||||
----------------
|
||||
|
||||
The mysqlclient and PyMySQL DBAPIs accept an additional dictionary under the
|
||||
key "ssl", which may be specified using the
|
||||
:paramref:`_sa.create_engine.connect_args` dictionary::
|
||||
|
||||
engine = create_engine(
|
||||
"mysql+mysqldb://scott:tiger@192.168.0.134/test",
|
||||
connect_args={
|
||||
"ssl": {
|
||||
"ca": "/home/gord/client-ssl/ca.pem",
|
||||
"cert": "/home/gord/client-ssl/client-cert.pem",
|
||||
"key": "/home/gord/client-ssl/client-key.pem"
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
For convenience, the following keys may also be specified inline within the URL
|
||||
where they will be interpreted into the "ssl" dictionary automatically:
|
||||
"ssl_ca", "ssl_cert", "ssl_key", "ssl_capath", "ssl_cipher",
|
||||
"ssl_check_hostname". An example is as follows::
|
||||
|
||||
connection_uri = (
|
||||
"mysql+mysqldb://scott:tiger@192.168.0.134/test"
|
||||
"?ssl_ca=/home/gord/client-ssl/ca.pem"
|
||||
"&ssl_cert=/home/gord/client-ssl/client-cert.pem"
|
||||
"&ssl_key=/home/gord/client-ssl/client-key.pem"
|
||||
)
|
||||
|
||||
.. seealso::
|
||||
|
||||
:ref:`pymysql_ssl` in the PyMySQL dialect
|
||||
|
||||
|
||||
Using MySQLdb with Google Cloud SQL
|
||||
-----------------------------------
|
||||
|
||||
Google Cloud SQL now recommends use of the MySQLdb dialect. Connect
|
||||
using a URL like the following::
|
||||
|
||||
mysql+mysqldb://root@/<dbname>?unix_socket=/cloudsql/<projectid>:<instancename>
|
||||
|
||||
Server Side Cursors
|
||||
-------------------
|
||||
|
||||
The mysqldb dialect supports server-side cursors. See :ref:`mysql_ss_cursors`.
|
||||
|
||||
"""
|
||||
|
||||
import re
|
||||
|
||||
from .base import MySQLCompiler
|
||||
from .base import MySQLDialect
|
||||
from .base import MySQLExecutionContext
|
||||
from .base import MySQLIdentifierPreparer
|
||||
from .base import TEXT
|
||||
from ... import sql
|
||||
from ... import util
|
||||
|
||||
|
||||
class MySQLExecutionContext_mysqldb(MySQLExecutionContext):
|
||||
@property
|
||||
def rowcount(self):
|
||||
if hasattr(self, "_rowcount"):
|
||||
return self._rowcount
|
||||
else:
|
||||
return self.cursor.rowcount
|
||||
|
||||
|
||||
class MySQLCompiler_mysqldb(MySQLCompiler):
|
||||
pass
|
||||
|
||||
|
||||
class MySQLDialect_mysqldb(MySQLDialect):
|
||||
driver = "mysqldb"
|
||||
supports_statement_cache = True
|
||||
supports_unicode_statements = True
|
||||
supports_sane_rowcount = True
|
||||
supports_sane_multi_rowcount = True
|
||||
|
||||
supports_native_decimal = True
|
||||
|
||||
default_paramstyle = "format"
|
||||
execution_ctx_cls = MySQLExecutionContext_mysqldb
|
||||
statement_compiler = MySQLCompiler_mysqldb
|
||||
preparer = MySQLIdentifierPreparer
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self._mysql_dbapi_version = (
|
||||
self._parse_dbapi_version(self.dbapi.__version__)
|
||||
if self.dbapi is not None and hasattr(self.dbapi, "__version__")
|
||||
else (0, 0, 0)
|
||||
)
|
||||
|
||||
def _parse_dbapi_version(self, version):
|
||||
m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", version)
|
||||
if m:
|
||||
return tuple(int(x) for x in m.group(1, 2, 3) if x is not None)
|
||||
else:
|
||||
return (0, 0, 0)
|
||||
|
||||
@util.langhelpers.memoized_property
|
||||
def supports_server_side_cursors(self):
|
||||
try:
|
||||
cursors = __import__("MySQLdb.cursors").cursors
|
||||
self._sscursor = cursors.SSCursor
|
||||
return True
|
||||
except (ImportError, AttributeError):
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def import_dbapi(cls):
|
||||
return __import__("MySQLdb")
|
||||
|
||||
def on_connect(self):
|
||||
super_ = super().on_connect()
|
||||
|
||||
def on_connect(conn):
|
||||
if super_ is not None:
|
||||
super_(conn)
|
||||
|
||||
charset_name = conn.character_set_name()
|
||||
|
||||
if charset_name is not None:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SET NAMES %s" % charset_name)
|
||||
cursor.close()
|
||||
|
||||
return on_connect
|
||||
|
||||
def do_ping(self, dbapi_connection):
|
||||
dbapi_connection.ping()
|
||||
return True
|
||||
|
||||
def do_executemany(self, cursor, statement, parameters, context=None):
|
||||
rowcount = cursor.executemany(statement, parameters)
|
||||
if context is not None:
|
||||
context._rowcount = rowcount
|
||||
|
||||
def _check_unicode_returns(self, connection):
|
||||
# work around issue fixed in
|
||||
# https://github.com/farcepest/MySQLdb1/commit/cd44524fef63bd3fcb71947392326e9742d520e8
|
||||
# specific issue w/ the utf8mb4_bin collation and unicode returns
|
||||
|
||||
collation = connection.exec_driver_sql(
|
||||
"show collation where %s = 'utf8mb4' and %s = 'utf8mb4_bin'"
|
||||
% (
|
||||
self.identifier_preparer.quote("Charset"),
|
||||
self.identifier_preparer.quote("Collation"),
|
||||
)
|
||||
).scalar()
|
||||
has_utf8mb4_bin = self.server_version_info > (5,) and collation
|
||||
if has_utf8mb4_bin:
|
||||
additional_tests = [
|
||||
sql.collate(
|
||||
sql.cast(
|
||||
sql.literal_column("'test collated returns'"),
|
||||
TEXT(charset="utf8mb4"),
|
||||
),
|
||||
"utf8mb4_bin",
|
||||
)
|
||||
]
|
||||
else:
|
||||
additional_tests = []
|
||||
return super()._check_unicode_returns(connection, additional_tests)
|
||||
|
||||
def create_connect_args(self, url, _translate_args=None):
|
||||
if _translate_args is None:
|
||||
_translate_args = dict(
|
||||
database="db", username="user", password="passwd"
|
||||
)
|
||||
|
||||
opts = url.translate_connect_args(**_translate_args)
|
||||
opts.update(url.query)
|
||||
|
||||
util.coerce_kw_type(opts, "compress", bool)
|
||||
util.coerce_kw_type(opts, "connect_timeout", int)
|
||||
util.coerce_kw_type(opts, "read_timeout", int)
|
||||
util.coerce_kw_type(opts, "write_timeout", int)
|
||||
util.coerce_kw_type(opts, "client_flag", int)
|
||||
util.coerce_kw_type(opts, "local_infile", int)
|
||||
# Note: using either of the below will cause all strings to be
|
||||
# returned as Unicode, both in raw SQL operations and with column
|
||||
# types like String and MSString.
|
||||
util.coerce_kw_type(opts, "use_unicode", bool)
|
||||
util.coerce_kw_type(opts, "charset", str)
|
||||
|
||||
# Rich values 'cursorclass' and 'conv' are not supported via
|
||||
# query string.
|
||||
|
||||
ssl = {}
|
||||
keys = [
|
||||
("ssl_ca", str),
|
||||
("ssl_key", str),
|
||||
("ssl_cert", str),
|
||||
("ssl_capath", str),
|
||||
("ssl_cipher", str),
|
||||
("ssl_check_hostname", bool),
|
||||
]
|
||||
for key, kw_type in keys:
|
||||
if key in opts:
|
||||
ssl[key[4:]] = opts[key]
|
||||
util.coerce_kw_type(ssl, key[4:], kw_type)
|
||||
del opts[key]
|
||||
if ssl:
|
||||
opts["ssl"] = ssl
|
||||
|
||||
# FOUND_ROWS must be set in CLIENT_FLAGS to enable
|
||||
# supports_sane_rowcount.
|
||||
client_flag = opts.get("client_flag", 0)
|
||||
|
||||
client_flag_found_rows = self._found_rows_client_flag()
|
||||
if client_flag_found_rows is not None:
|
||||
client_flag |= client_flag_found_rows
|
||||
opts["client_flag"] = client_flag
|
||||
return [[], opts]
|
||||
|
||||
def _found_rows_client_flag(self):
|
||||
if self.dbapi is not None:
|
||||
try:
|
||||
CLIENT_FLAGS = __import__(
|
||||
self.dbapi.__name__ + ".constants.CLIENT"
|
||||
).constants.CLIENT
|
||||
except (AttributeError, ImportError):
|
||||
return None
|
||||
else:
|
||||
return CLIENT_FLAGS.FOUND_ROWS
|
||||
else:
|
||||
return None
|
||||
|
||||
def _extract_error_code(self, exception):
|
||||
return exception.args[0]
|
||||
|
||||
def _detect_charset(self, connection):
|
||||
"""Sniff out the character set in use for connection results."""
|
||||
|
||||
try:
|
||||
# note: the SQL here would be
|
||||
# "SHOW VARIABLES LIKE 'character_set%%'"
|
||||
cset_name = connection.connection.character_set_name
|
||||
except AttributeError:
|
||||
util.warn(
|
||||
"No 'character_set_name' can be detected with "
|
||||
"this MySQL-Python version; "
|
||||
"please upgrade to a recent version of MySQL-Python. "
|
||||
"Assuming latin1."
|
||||
)
|
||||
return "latin1"
|
||||
else:
|
||||
return cset_name()
|
||||
|
||||
def get_isolation_level_values(self, dbapi_connection):
|
||||
return (
|
||||
"SERIALIZABLE",
|
||||
"READ UNCOMMITTED",
|
||||
"READ COMMITTED",
|
||||
"REPEATABLE READ",
|
||||
"AUTOCOMMIT",
|
||||
)
|
||||
|
||||
def set_isolation_level(self, dbapi_connection, level):
|
||||
if level == "AUTOCOMMIT":
|
||||
dbapi_connection.autocommit(True)
|
||||
else:
|
||||
dbapi_connection.autocommit(False)
|
||||
super().set_isolation_level(dbapi_connection, level)
|
||||
|
||||
|
||||
dialect = MySQLDialect_mysqldb
|
|
@ -0,0 +1,107 @@
|
|||
# dialects/mysql/provision.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
from ... import exc
|
||||
from ...testing.provision import configure_follower
|
||||
from ...testing.provision import create_db
|
||||
from ...testing.provision import drop_db
|
||||
from ...testing.provision import generate_driver_url
|
||||
from ...testing.provision import temp_table_keyword_args
|
||||
from ...testing.provision import upsert
|
||||
|
||||
|
||||
@generate_driver_url.for_db("mysql", "mariadb")
|
||||
def generate_driver_url(url, driver, query_str):
|
||||
backend = url.get_backend_name()
|
||||
|
||||
# NOTE: at the moment, tests are running mariadbconnector
|
||||
# against both mariadb and mysql backends. if we want this to be
|
||||
# limited, do the decision making here to reject a "mysql+mariadbconnector"
|
||||
# URL. Optionally also re-enable the module level
|
||||
# MySQLDialect_mariadbconnector.is_mysql flag as well, which must include
|
||||
# a unit and/or functional test.
|
||||
|
||||
# all the Jenkins tests have been running mysqlclient Python library
|
||||
# built against mariadb client drivers for years against all MySQL /
|
||||
# MariaDB versions going back to MySQL 5.6, currently they can talk
|
||||
# to MySQL databases without problems.
|
||||
|
||||
if backend == "mysql":
|
||||
dialect_cls = url.get_dialect()
|
||||
if dialect_cls._is_mariadb_from_url(url):
|
||||
backend = "mariadb"
|
||||
|
||||
new_url = url.set(
|
||||
drivername="%s+%s" % (backend, driver)
|
||||
).update_query_string(query_str)
|
||||
|
||||
try:
|
||||
new_url.get_dialect()
|
||||
except exc.NoSuchModuleError:
|
||||
return None
|
||||
else:
|
||||
return new_url
|
||||
|
||||
|
||||
@create_db.for_db("mysql", "mariadb")
|
||||
def _mysql_create_db(cfg, eng, ident):
|
||||
with eng.begin() as conn:
|
||||
try:
|
||||
_mysql_drop_db(cfg, conn, ident)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
with eng.begin() as conn:
|
||||
conn.exec_driver_sql(
|
||||
"CREATE DATABASE %s CHARACTER SET utf8mb4" % ident
|
||||
)
|
||||
conn.exec_driver_sql(
|
||||
"CREATE DATABASE %s_test_schema CHARACTER SET utf8mb4" % ident
|
||||
)
|
||||
conn.exec_driver_sql(
|
||||
"CREATE DATABASE %s_test_schema_2 CHARACTER SET utf8mb4" % ident
|
||||
)
|
||||
|
||||
|
||||
@configure_follower.for_db("mysql", "mariadb")
|
||||
def _mysql_configure_follower(config, ident):
|
||||
config.test_schema = "%s_test_schema" % ident
|
||||
config.test_schema_2 = "%s_test_schema_2" % ident
|
||||
|
||||
|
||||
@drop_db.for_db("mysql", "mariadb")
|
||||
def _mysql_drop_db(cfg, eng, ident):
|
||||
with eng.begin() as conn:
|
||||
conn.exec_driver_sql("DROP DATABASE %s_test_schema" % ident)
|
||||
conn.exec_driver_sql("DROP DATABASE %s_test_schema_2" % ident)
|
||||
conn.exec_driver_sql("DROP DATABASE %s" % ident)
|
||||
|
||||
|
||||
@temp_table_keyword_args.for_db("mysql", "mariadb")
|
||||
def _mysql_temp_table_keyword_args(cfg, eng):
|
||||
return {"prefixes": ["TEMPORARY"]}
|
||||
|
||||
|
||||
@upsert.for_db("mariadb")
|
||||
def _upsert(
|
||||
cfg, table, returning, *, set_lambda=None, sort_by_parameter_order=False
|
||||
):
|
||||
from sqlalchemy.dialects.mysql import insert
|
||||
|
||||
stmt = insert(table)
|
||||
|
||||
if set_lambda:
|
||||
stmt = stmt.on_duplicate_key_update(**set_lambda(stmt.inserted))
|
||||
else:
|
||||
pk1 = table.primary_key.c[0]
|
||||
stmt = stmt.on_duplicate_key_update({pk1.key: pk1})
|
||||
|
||||
stmt = stmt.returning(
|
||||
*returning, sort_by_parameter_order=sort_by_parameter_order
|
||||
)
|
||||
return stmt
|
|
@ -0,0 +1,137 @@
|
|||
# dialects/mysql/pymysql.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
r"""
|
||||
|
||||
.. dialect:: mysql+pymysql
|
||||
:name: PyMySQL
|
||||
:dbapi: pymysql
|
||||
:connectstring: mysql+pymysql://<username>:<password>@<host>/<dbname>[?<options>]
|
||||
:url: https://pymysql.readthedocs.io/
|
||||
|
||||
Unicode
|
||||
-------
|
||||
|
||||
Please see :ref:`mysql_unicode` for current recommendations on unicode
|
||||
handling.
|
||||
|
||||
.. _pymysql_ssl:
|
||||
|
||||
SSL Connections
|
||||
------------------
|
||||
|
||||
The PyMySQL DBAPI accepts the same SSL arguments as that of MySQLdb,
|
||||
described at :ref:`mysqldb_ssl`. See that section for additional examples.
|
||||
|
||||
If the server uses an automatically-generated certificate that is self-signed
|
||||
or does not match the host name (as seen from the client), it may also be
|
||||
necessary to indicate ``ssl_check_hostname=false`` in PyMySQL::
|
||||
|
||||
connection_uri = (
|
||||
"mysql+pymysql://scott:tiger@192.168.0.134/test"
|
||||
"?ssl_ca=/home/gord/client-ssl/ca.pem"
|
||||
"&ssl_cert=/home/gord/client-ssl/client-cert.pem"
|
||||
"&ssl_key=/home/gord/client-ssl/client-key.pem"
|
||||
"&ssl_check_hostname=false"
|
||||
)
|
||||
|
||||
|
||||
MySQL-Python Compatibility
|
||||
--------------------------
|
||||
|
||||
The pymysql DBAPI is a pure Python port of the MySQL-python (MySQLdb) driver,
|
||||
and targets 100% compatibility. Most behavioral notes for MySQL-python apply
|
||||
to the pymysql driver as well.
|
||||
|
||||
""" # noqa
|
||||
|
||||
from .mysqldb import MySQLDialect_mysqldb
|
||||
from ...util import langhelpers
|
||||
|
||||
|
||||
class MySQLDialect_pymysql(MySQLDialect_mysqldb):
|
||||
driver = "pymysql"
|
||||
supports_statement_cache = True
|
||||
|
||||
description_encoding = None
|
||||
|
||||
@langhelpers.memoized_property
|
||||
def supports_server_side_cursors(self):
|
||||
try:
|
||||
cursors = __import__("pymysql.cursors").cursors
|
||||
self._sscursor = cursors.SSCursor
|
||||
return True
|
||||
except (ImportError, AttributeError):
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def import_dbapi(cls):
|
||||
return __import__("pymysql")
|
||||
|
||||
@langhelpers.memoized_property
|
||||
def _send_false_to_ping(self):
|
||||
"""determine if pymysql has deprecated, changed the default of,
|
||||
or removed the 'reconnect' argument of connection.ping().
|
||||
|
||||
See #10492 and
|
||||
https://github.com/PyMySQL/mysqlclient/discussions/651#discussioncomment-7308971
|
||||
for background.
|
||||
|
||||
""" # noqa: E501
|
||||
|
||||
try:
|
||||
Connection = __import__(
|
||||
"pymysql.connections"
|
||||
).connections.Connection
|
||||
except (ImportError, AttributeError):
|
||||
return True
|
||||
else:
|
||||
insp = langhelpers.get_callable_argspec(Connection.ping)
|
||||
try:
|
||||
reconnect_arg = insp.args[1]
|
||||
except IndexError:
|
||||
return False
|
||||
else:
|
||||
return reconnect_arg == "reconnect" and (
|
||||
not insp.defaults or insp.defaults[0] is not False
|
||||
)
|
||||
|
||||
def do_ping(self, dbapi_connection):
|
||||
if self._send_false_to_ping:
|
||||
dbapi_connection.ping(False)
|
||||
else:
|
||||
dbapi_connection.ping()
|
||||
|
||||
return True
|
||||
|
||||
def create_connect_args(self, url, _translate_args=None):
|
||||
if _translate_args is None:
|
||||
_translate_args = dict(username="user")
|
||||
return super().create_connect_args(
|
||||
url, _translate_args=_translate_args
|
||||
)
|
||||
|
||||
def is_disconnect(self, e, connection, cursor):
|
||||
if super().is_disconnect(e, connection, cursor):
|
||||
return True
|
||||
elif isinstance(e, self.dbapi.Error):
|
||||
str_e = str(e).lower()
|
||||
return (
|
||||
"already closed" in str_e or "connection was killed" in str_e
|
||||
)
|
||||
else:
|
||||
return False
|
||||
|
||||
def _extract_error_code(self, exception):
|
||||
if isinstance(exception.args[0], Exception):
|
||||
exception = exception.args[0]
|
||||
return exception.args[0]
|
||||
|
||||
|
||||
dialect = MySQLDialect_pymysql
|
|
@ -0,0 +1,138 @@
|
|||
# dialects/mysql/pyodbc.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
r"""
|
||||
|
||||
|
||||
.. dialect:: mysql+pyodbc
|
||||
:name: PyODBC
|
||||
:dbapi: pyodbc
|
||||
:connectstring: mysql+pyodbc://<username>:<password>@<dsnname>
|
||||
:url: https://pypi.org/project/pyodbc/
|
||||
|
||||
.. note::
|
||||
|
||||
The PyODBC for MySQL dialect is **not tested as part of
|
||||
SQLAlchemy's continuous integration**.
|
||||
The recommended MySQL dialects are mysqlclient and PyMySQL.
|
||||
However, if you want to use the mysql+pyodbc dialect and require
|
||||
full support for ``utf8mb4`` characters (including supplementary
|
||||
characters like emoji) be sure to use a current release of
|
||||
MySQL Connector/ODBC and specify the "ANSI" (**not** "Unicode")
|
||||
version of the driver in your DSN or connection string.
|
||||
|
||||
Pass through exact pyodbc connection string::
|
||||
|
||||
import urllib
|
||||
connection_string = (
|
||||
'DRIVER=MySQL ODBC 8.0 ANSI Driver;'
|
||||
'SERVER=localhost;'
|
||||
'PORT=3307;'
|
||||
'DATABASE=mydb;'
|
||||
'UID=root;'
|
||||
'PWD=(whatever);'
|
||||
'charset=utf8mb4;'
|
||||
)
|
||||
params = urllib.parse.quote_plus(connection_string)
|
||||
connection_uri = "mysql+pyodbc:///?odbc_connect=%s" % params
|
||||
|
||||
""" # noqa
|
||||
|
||||
import re
|
||||
|
||||
from .base import MySQLDialect
|
||||
from .base import MySQLExecutionContext
|
||||
from .types import TIME
|
||||
from ... import exc
|
||||
from ... import util
|
||||
from ...connectors.pyodbc import PyODBCConnector
|
||||
from ...sql.sqltypes import Time
|
||||
|
||||
|
||||
class _pyodbcTIME(TIME):
|
||||
def result_processor(self, dialect, coltype):
|
||||
def process(value):
|
||||
# pyodbc returns a datetime.time object; no need to convert
|
||||
return value
|
||||
|
||||
return process
|
||||
|
||||
|
||||
class MySQLExecutionContext_pyodbc(MySQLExecutionContext):
|
||||
def get_lastrowid(self):
|
||||
cursor = self.create_cursor()
|
||||
cursor.execute("SELECT LAST_INSERT_ID()")
|
||||
lastrowid = cursor.fetchone()[0]
|
||||
cursor.close()
|
||||
return lastrowid
|
||||
|
||||
|
||||
class MySQLDialect_pyodbc(PyODBCConnector, MySQLDialect):
|
||||
supports_statement_cache = True
|
||||
colspecs = util.update_copy(MySQLDialect.colspecs, {Time: _pyodbcTIME})
|
||||
supports_unicode_statements = True
|
||||
execution_ctx_cls = MySQLExecutionContext_pyodbc
|
||||
|
||||
pyodbc_driver_name = "MySQL"
|
||||
|
||||
def _detect_charset(self, connection):
|
||||
"""Sniff out the character set in use for connection results."""
|
||||
|
||||
# Prefer 'character_set_results' for the current connection over the
|
||||
# value in the driver. SET NAMES or individual variable SETs will
|
||||
# change the charset without updating the driver's view of the world.
|
||||
#
|
||||
# If it's decided that issuing that sort of SQL leaves you SOL, then
|
||||
# this can prefer the driver value.
|
||||
|
||||
# set this to None as _fetch_setting attempts to use it (None is OK)
|
||||
self._connection_charset = None
|
||||
try:
|
||||
value = self._fetch_setting(connection, "character_set_client")
|
||||
if value:
|
||||
return value
|
||||
except exc.DBAPIError:
|
||||
pass
|
||||
|
||||
util.warn(
|
||||
"Could not detect the connection character set. "
|
||||
"Assuming latin1."
|
||||
)
|
||||
return "latin1"
|
||||
|
||||
def _get_server_version_info(self, connection):
|
||||
return MySQLDialect._get_server_version_info(self, connection)
|
||||
|
||||
def _extract_error_code(self, exception):
|
||||
m = re.compile(r"\((\d+)\)").search(str(exception.args))
|
||||
c = m.group(1)
|
||||
if c:
|
||||
return int(c)
|
||||
else:
|
||||
return None
|
||||
|
||||
def on_connect(self):
|
||||
super_ = super().on_connect()
|
||||
|
||||
def on_connect(conn):
|
||||
if super_ is not None:
|
||||
super_(conn)
|
||||
|
||||
# declare Unicode encoding for pyodbc as per
|
||||
# https://github.com/mkleehammer/pyodbc/wiki/Unicode
|
||||
pyodbc_SQL_CHAR = 1 # pyodbc.SQL_CHAR
|
||||
pyodbc_SQL_WCHAR = -8 # pyodbc.SQL_WCHAR
|
||||
conn.setdecoding(pyodbc_SQL_CHAR, encoding="utf-8")
|
||||
conn.setdecoding(pyodbc_SQL_WCHAR, encoding="utf-8")
|
||||
conn.setencoding(encoding="utf-8")
|
||||
|
||||
return on_connect
|
||||
|
||||
|
||||
dialect = MySQLDialect_pyodbc
|
|
@ -0,0 +1,677 @@
|
|||
# dialects/mysql/reflection.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
import re
|
||||
|
||||
from .enumerated import ENUM
|
||||
from .enumerated import SET
|
||||
from .types import DATETIME
|
||||
from .types import TIME
|
||||
from .types import TIMESTAMP
|
||||
from ... import log
|
||||
from ... import types as sqltypes
|
||||
from ... import util
|
||||
|
||||
|
||||
class ReflectedState:
|
||||
"""Stores raw information about a SHOW CREATE TABLE statement."""
|
||||
|
||||
def __init__(self):
|
||||
self.columns = []
|
||||
self.table_options = {}
|
||||
self.table_name = None
|
||||
self.keys = []
|
||||
self.fk_constraints = []
|
||||
self.ck_constraints = []
|
||||
|
||||
|
||||
@log.class_logger
|
||||
class MySQLTableDefinitionParser:
|
||||
"""Parses the results of a SHOW CREATE TABLE statement."""
|
||||
|
||||
def __init__(self, dialect, preparer):
|
||||
self.dialect = dialect
|
||||
self.preparer = preparer
|
||||
self._prep_regexes()
|
||||
|
||||
def parse(self, show_create, charset):
|
||||
state = ReflectedState()
|
||||
state.charset = charset
|
||||
for line in re.split(r"\r?\n", show_create):
|
||||
if line.startswith(" " + self.preparer.initial_quote):
|
||||
self._parse_column(line, state)
|
||||
# a regular table options line
|
||||
elif line.startswith(") "):
|
||||
self._parse_table_options(line, state)
|
||||
# an ANSI-mode table options line
|
||||
elif line == ")":
|
||||
pass
|
||||
elif line.startswith("CREATE "):
|
||||
self._parse_table_name(line, state)
|
||||
elif "PARTITION" in line:
|
||||
self._parse_partition_options(line, state)
|
||||
# Not present in real reflection, but may be if
|
||||
# loading from a file.
|
||||
elif not line:
|
||||
pass
|
||||
else:
|
||||
type_, spec = self._parse_constraints(line)
|
||||
if type_ is None:
|
||||
util.warn("Unknown schema content: %r" % line)
|
||||
elif type_ == "key":
|
||||
state.keys.append(spec)
|
||||
elif type_ == "fk_constraint":
|
||||
state.fk_constraints.append(spec)
|
||||
elif type_ == "ck_constraint":
|
||||
state.ck_constraints.append(spec)
|
||||
else:
|
||||
pass
|
||||
return state
|
||||
|
||||
def _check_view(self, sql: str) -> bool:
|
||||
return bool(self._re_is_view.match(sql))
|
||||
|
||||
def _parse_constraints(self, line):
|
||||
"""Parse a KEY or CONSTRAINT line.
|
||||
|
||||
:param line: A line of SHOW CREATE TABLE output
|
||||
"""
|
||||
|
||||
# KEY
|
||||
m = self._re_key.match(line)
|
||||
if m:
|
||||
spec = m.groupdict()
|
||||
# convert columns into name, length pairs
|
||||
# NOTE: we may want to consider SHOW INDEX as the
|
||||
# format of indexes in MySQL becomes more complex
|
||||
spec["columns"] = self._parse_keyexprs(spec["columns"])
|
||||
if spec["version_sql"]:
|
||||
m2 = self._re_key_version_sql.match(spec["version_sql"])
|
||||
if m2 and m2.groupdict()["parser"]:
|
||||
spec["parser"] = m2.groupdict()["parser"]
|
||||
if spec["parser"]:
|
||||
spec["parser"] = self.preparer.unformat_identifiers(
|
||||
spec["parser"]
|
||||
)[0]
|
||||
return "key", spec
|
||||
|
||||
# FOREIGN KEY CONSTRAINT
|
||||
m = self._re_fk_constraint.match(line)
|
||||
if m:
|
||||
spec = m.groupdict()
|
||||
spec["table"] = self.preparer.unformat_identifiers(spec["table"])
|
||||
spec["local"] = [c[0] for c in self._parse_keyexprs(spec["local"])]
|
||||
spec["foreign"] = [
|
||||
c[0] for c in self._parse_keyexprs(spec["foreign"])
|
||||
]
|
||||
return "fk_constraint", spec
|
||||
|
||||
# CHECK constraint
|
||||
m = self._re_ck_constraint.match(line)
|
||||
if m:
|
||||
spec = m.groupdict()
|
||||
return "ck_constraint", spec
|
||||
|
||||
# PARTITION and SUBPARTITION
|
||||
m = self._re_partition.match(line)
|
||||
if m:
|
||||
# Punt!
|
||||
return "partition", line
|
||||
|
||||
# No match.
|
||||
return (None, line)
|
||||
|
||||
def _parse_table_name(self, line, state):
|
||||
"""Extract the table name.
|
||||
|
||||
:param line: The first line of SHOW CREATE TABLE
|
||||
"""
|
||||
|
||||
regex, cleanup = self._pr_name
|
||||
m = regex.match(line)
|
||||
if m:
|
||||
state.table_name = cleanup(m.group("name"))
|
||||
|
||||
def _parse_table_options(self, line, state):
|
||||
"""Build a dictionary of all reflected table-level options.
|
||||
|
||||
:param line: The final line of SHOW CREATE TABLE output.
|
||||
"""
|
||||
|
||||
options = {}
|
||||
|
||||
if line and line != ")":
|
||||
rest_of_line = line
|
||||
for regex, cleanup in self._pr_options:
|
||||
m = regex.search(rest_of_line)
|
||||
if not m:
|
||||
continue
|
||||
directive, value = m.group("directive"), m.group("val")
|
||||
if cleanup:
|
||||
value = cleanup(value)
|
||||
options[directive.lower()] = value
|
||||
rest_of_line = regex.sub("", rest_of_line)
|
||||
|
||||
for nope in ("auto_increment", "data directory", "index directory"):
|
||||
options.pop(nope, None)
|
||||
|
||||
for opt, val in options.items():
|
||||
state.table_options["%s_%s" % (self.dialect.name, opt)] = val
|
||||
|
||||
def _parse_partition_options(self, line, state):
|
||||
options = {}
|
||||
new_line = line[:]
|
||||
|
||||
while new_line.startswith("(") or new_line.startswith(" "):
|
||||
new_line = new_line[1:]
|
||||
|
||||
for regex, cleanup in self._pr_options:
|
||||
m = regex.search(new_line)
|
||||
if not m or "PARTITION" not in regex.pattern:
|
||||
continue
|
||||
|
||||
directive = m.group("directive")
|
||||
directive = directive.lower()
|
||||
is_subpartition = directive == "subpartition"
|
||||
|
||||
if directive == "partition" or is_subpartition:
|
||||
new_line = new_line.replace(") */", "")
|
||||
new_line = new_line.replace(",", "")
|
||||
if is_subpartition and new_line.endswith(")"):
|
||||
new_line = new_line[:-1]
|
||||
if self.dialect.name == "mariadb" and new_line.endswith(")"):
|
||||
if (
|
||||
"MAXVALUE" in new_line
|
||||
or "MINVALUE" in new_line
|
||||
or "ENGINE" in new_line
|
||||
):
|
||||
# final line of MariaDB partition endswith ")"
|
||||
new_line = new_line[:-1]
|
||||
|
||||
defs = "%s_%s_definitions" % (self.dialect.name, directive)
|
||||
options[defs] = new_line
|
||||
|
||||
else:
|
||||
directive = directive.replace(" ", "_")
|
||||
value = m.group("val")
|
||||
if cleanup:
|
||||
value = cleanup(value)
|
||||
options[directive] = value
|
||||
break
|
||||
|
||||
for opt, val in options.items():
|
||||
part_def = "%s_partition_definitions" % (self.dialect.name)
|
||||
subpart_def = "%s_subpartition_definitions" % (self.dialect.name)
|
||||
if opt == part_def or opt == subpart_def:
|
||||
# builds a string of definitions
|
||||
if opt not in state.table_options:
|
||||
state.table_options[opt] = val
|
||||
else:
|
||||
state.table_options[opt] = "%s, %s" % (
|
||||
state.table_options[opt],
|
||||
val,
|
||||
)
|
||||
else:
|
||||
state.table_options["%s_%s" % (self.dialect.name, opt)] = val
|
||||
|
||||
def _parse_column(self, line, state):
|
||||
"""Extract column details.
|
||||
|
||||
Falls back to a 'minimal support' variant if full parse fails.
|
||||
|
||||
:param line: Any column-bearing line from SHOW CREATE TABLE
|
||||
"""
|
||||
|
||||
spec = None
|
||||
m = self._re_column.match(line)
|
||||
if m:
|
||||
spec = m.groupdict()
|
||||
spec["full"] = True
|
||||
else:
|
||||
m = self._re_column_loose.match(line)
|
||||
if m:
|
||||
spec = m.groupdict()
|
||||
spec["full"] = False
|
||||
if not spec:
|
||||
util.warn("Unknown column definition %r" % line)
|
||||
return
|
||||
if not spec["full"]:
|
||||
util.warn("Incomplete reflection of column definition %r" % line)
|
||||
|
||||
name, type_, args = spec["name"], spec["coltype"], spec["arg"]
|
||||
|
||||
try:
|
||||
col_type = self.dialect.ischema_names[type_]
|
||||
except KeyError:
|
||||
util.warn(
|
||||
"Did not recognize type '%s' of column '%s'" % (type_, name)
|
||||
)
|
||||
col_type = sqltypes.NullType
|
||||
|
||||
# Column type positional arguments eg. varchar(32)
|
||||
if args is None or args == "":
|
||||
type_args = []
|
||||
elif args[0] == "'" and args[-1] == "'":
|
||||
type_args = self._re_csv_str.findall(args)
|
||||
else:
|
||||
type_args = [int(v) for v in self._re_csv_int.findall(args)]
|
||||
|
||||
# Column type keyword options
|
||||
type_kw = {}
|
||||
|
||||
if issubclass(col_type, (DATETIME, TIME, TIMESTAMP)):
|
||||
if type_args:
|
||||
type_kw["fsp"] = type_args.pop(0)
|
||||
|
||||
for kw in ("unsigned", "zerofill"):
|
||||
if spec.get(kw, False):
|
||||
type_kw[kw] = True
|
||||
for kw in ("charset", "collate"):
|
||||
if spec.get(kw, False):
|
||||
type_kw[kw] = spec[kw]
|
||||
if issubclass(col_type, (ENUM, SET)):
|
||||
type_args = _strip_values(type_args)
|
||||
|
||||
if issubclass(col_type, SET) and "" in type_args:
|
||||
type_kw["retrieve_as_bitwise"] = True
|
||||
|
||||
type_instance = col_type(*type_args, **type_kw)
|
||||
|
||||
col_kw = {}
|
||||
|
||||
# NOT NULL
|
||||
col_kw["nullable"] = True
|
||||
# this can be "NULL" in the case of TIMESTAMP
|
||||
if spec.get("notnull", False) == "NOT NULL":
|
||||
col_kw["nullable"] = False
|
||||
# For generated columns, the nullability is marked in a different place
|
||||
if spec.get("notnull_generated", False) == "NOT NULL":
|
||||
col_kw["nullable"] = False
|
||||
|
||||
# AUTO_INCREMENT
|
||||
if spec.get("autoincr", False):
|
||||
col_kw["autoincrement"] = True
|
||||
elif issubclass(col_type, sqltypes.Integer):
|
||||
col_kw["autoincrement"] = False
|
||||
|
||||
# DEFAULT
|
||||
default = spec.get("default", None)
|
||||
|
||||
if default == "NULL":
|
||||
# eliminates the need to deal with this later.
|
||||
default = None
|
||||
|
||||
comment = spec.get("comment", None)
|
||||
|
||||
if comment is not None:
|
||||
comment = cleanup_text(comment)
|
||||
|
||||
sqltext = spec.get("generated")
|
||||
if sqltext is not None:
|
||||
computed = dict(sqltext=sqltext)
|
||||
persisted = spec.get("persistence")
|
||||
if persisted is not None:
|
||||
computed["persisted"] = persisted == "STORED"
|
||||
col_kw["computed"] = computed
|
||||
|
||||
col_d = dict(
|
||||
name=name, type=type_instance, default=default, comment=comment
|
||||
)
|
||||
col_d.update(col_kw)
|
||||
state.columns.append(col_d)
|
||||
|
||||
def _describe_to_create(self, table_name, columns):
|
||||
"""Re-format DESCRIBE output as a SHOW CREATE TABLE string.
|
||||
|
||||
DESCRIBE is a much simpler reflection and is sufficient for
|
||||
reflecting views for runtime use. This method formats DDL
|
||||
for columns only- keys are omitted.
|
||||
|
||||
:param columns: A sequence of DESCRIBE or SHOW COLUMNS 6-tuples.
|
||||
SHOW FULL COLUMNS FROM rows must be rearranged for use with
|
||||
this function.
|
||||
"""
|
||||
|
||||
buffer = []
|
||||
for row in columns:
|
||||
(name, col_type, nullable, default, extra) = (
|
||||
row[i] for i in (0, 1, 2, 4, 5)
|
||||
)
|
||||
|
||||
line = [" "]
|
||||
line.append(self.preparer.quote_identifier(name))
|
||||
line.append(col_type)
|
||||
if not nullable:
|
||||
line.append("NOT NULL")
|
||||
if default:
|
||||
if "auto_increment" in default:
|
||||
pass
|
||||
elif col_type.startswith("timestamp") and default.startswith(
|
||||
"C"
|
||||
):
|
||||
line.append("DEFAULT")
|
||||
line.append(default)
|
||||
elif default == "NULL":
|
||||
line.append("DEFAULT")
|
||||
line.append(default)
|
||||
else:
|
||||
line.append("DEFAULT")
|
||||
line.append("'%s'" % default.replace("'", "''"))
|
||||
if extra:
|
||||
line.append(extra)
|
||||
|
||||
buffer.append(" ".join(line))
|
||||
|
||||
return "".join(
|
||||
[
|
||||
(
|
||||
"CREATE TABLE %s (\n"
|
||||
% self.preparer.quote_identifier(table_name)
|
||||
),
|
||||
",\n".join(buffer),
|
||||
"\n) ",
|
||||
]
|
||||
)
|
||||
|
||||
def _parse_keyexprs(self, identifiers):
|
||||
"""Unpack '"col"(2),"col" ASC'-ish strings into components."""
|
||||
|
||||
return [
|
||||
(colname, int(length) if length else None, modifiers)
|
||||
for colname, length, modifiers in self._re_keyexprs.findall(
|
||||
identifiers
|
||||
)
|
||||
]
|
||||
|
||||
def _prep_regexes(self):
|
||||
"""Pre-compile regular expressions."""
|
||||
|
||||
self._re_columns = []
|
||||
self._pr_options = []
|
||||
|
||||
_final = self.preparer.final_quote
|
||||
|
||||
quotes = dict(
|
||||
zip(
|
||||
("iq", "fq", "esc_fq"),
|
||||
[
|
||||
re.escape(s)
|
||||
for s in (
|
||||
self.preparer.initial_quote,
|
||||
_final,
|
||||
self.preparer._escape_identifier(_final),
|
||||
)
|
||||
],
|
||||
)
|
||||
)
|
||||
|
||||
self._pr_name = _pr_compile(
|
||||
r"^CREATE (?:\w+ +)?TABLE +"
|
||||
r"%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +\($" % quotes,
|
||||
self.preparer._unescape_identifier,
|
||||
)
|
||||
|
||||
self._re_is_view = _re_compile(r"^CREATE(?! TABLE)(\s.*)?\sVIEW")
|
||||
|
||||
# `col`,`col2`(32),`col3`(15) DESC
|
||||
#
|
||||
self._re_keyexprs = _re_compile(
|
||||
r"(?:"
|
||||
r"(?:%(iq)s((?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)"
|
||||
r"(?:\((\d+)\))?(?: +(ASC|DESC))?(?=\,|$))+" % quotes
|
||||
)
|
||||
|
||||
# 'foo' or 'foo','bar' or 'fo,o','ba''a''r'
|
||||
self._re_csv_str = _re_compile(r"\x27(?:\x27\x27|[^\x27])*\x27")
|
||||
|
||||
# 123 or 123,456
|
||||
self._re_csv_int = _re_compile(r"\d+")
|
||||
|
||||
# `colname` <type> [type opts]
|
||||
# (NOT NULL | NULL)
|
||||
# DEFAULT ('value' | CURRENT_TIMESTAMP...)
|
||||
# COMMENT 'comment'
|
||||
# COLUMN_FORMAT (FIXED|DYNAMIC|DEFAULT)
|
||||
# STORAGE (DISK|MEMORY)
|
||||
self._re_column = _re_compile(
|
||||
r" "
|
||||
r"%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +"
|
||||
r"(?P<coltype>\w+)"
|
||||
r"(?:\((?P<arg>(?:\d+|\d+,\d+|"
|
||||
r"(?:'(?:''|[^'])*',?)+))\))?"
|
||||
r"(?: +(?P<unsigned>UNSIGNED))?"
|
||||
r"(?: +(?P<zerofill>ZEROFILL))?"
|
||||
r"(?: +CHARACTER SET +(?P<charset>[\w_]+))?"
|
||||
r"(?: +COLLATE +(?P<collate>[\w_]+))?"
|
||||
r"(?: +(?P<notnull>(?:NOT )?NULL))?"
|
||||
r"(?: +DEFAULT +(?P<default>"
|
||||
r"(?:NULL|'(?:''|[^'])*'|[\-\w\.\(\)]+"
|
||||
r"(?: +ON UPDATE [\-\w\.\(\)]+)?)"
|
||||
r"))?"
|
||||
r"(?: +(?:GENERATED ALWAYS)? ?AS +(?P<generated>\("
|
||||
r".*\))? ?(?P<persistence>VIRTUAL|STORED)?"
|
||||
r"(?: +(?P<notnull_generated>(?:NOT )?NULL))?"
|
||||
r")?"
|
||||
r"(?: +(?P<autoincr>AUTO_INCREMENT))?"
|
||||
r"(?: +COMMENT +'(?P<comment>(?:''|[^'])*)')?"
|
||||
r"(?: +COLUMN_FORMAT +(?P<colfmt>\w+))?"
|
||||
r"(?: +STORAGE +(?P<storage>\w+))?"
|
||||
r"(?: +(?P<extra>.*))?"
|
||||
r",?$" % quotes
|
||||
)
|
||||
|
||||
# Fallback, try to parse as little as possible
|
||||
self._re_column_loose = _re_compile(
|
||||
r" "
|
||||
r"%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +"
|
||||
r"(?P<coltype>\w+)"
|
||||
r"(?:\((?P<arg>(?:\d+|\d+,\d+|\x27(?:\x27\x27|[^\x27])+\x27))\))?"
|
||||
r".*?(?P<notnull>(?:NOT )NULL)?" % quotes
|
||||
)
|
||||
|
||||
# (PRIMARY|UNIQUE|FULLTEXT|SPATIAL) INDEX `name` (USING (BTREE|HASH))?
|
||||
# (`col` (ASC|DESC)?, `col` (ASC|DESC)?)
|
||||
# KEY_BLOCK_SIZE size | WITH PARSER name /*!50100 WITH PARSER name */
|
||||
self._re_key = _re_compile(
|
||||
r" "
|
||||
r"(?:(?P<type>\S+) )?KEY"
|
||||
r"(?: +%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s)?"
|
||||
r"(?: +USING +(?P<using_pre>\S+))?"
|
||||
r" +\((?P<columns>.+?)\)"
|
||||
r"(?: +USING +(?P<using_post>\S+))?"
|
||||
r"(?: +KEY_BLOCK_SIZE *[ =]? *(?P<keyblock>\S+))?"
|
||||
r"(?: +WITH PARSER +(?P<parser>\S+))?"
|
||||
r"(?: +COMMENT +(?P<comment>(\x27\x27|\x27([^\x27])*?\x27)+))?"
|
||||
r"(?: +/\*(?P<version_sql>.+)\*/ *)?"
|
||||
r",?$" % quotes
|
||||
)
|
||||
|
||||
# https://forums.mysql.com/read.php?20,567102,567111#msg-567111
|
||||
# It means if the MySQL version >= \d+, execute what's in the comment
|
||||
self._re_key_version_sql = _re_compile(
|
||||
r"\!\d+ " r"(?: *WITH PARSER +(?P<parser>\S+) *)?"
|
||||
)
|
||||
|
||||
# CONSTRAINT `name` FOREIGN KEY (`local_col`)
|
||||
# REFERENCES `remote` (`remote_col`)
|
||||
# MATCH FULL | MATCH PARTIAL | MATCH SIMPLE
|
||||
# ON DELETE CASCADE ON UPDATE RESTRICT
|
||||
#
|
||||
# unique constraints come back as KEYs
|
||||
kw = quotes.copy()
|
||||
kw["on"] = "RESTRICT|CASCADE|SET NULL|NO ACTION"
|
||||
self._re_fk_constraint = _re_compile(
|
||||
r" "
|
||||
r"CONSTRAINT +"
|
||||
r"%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +"
|
||||
r"FOREIGN KEY +"
|
||||
r"\((?P<local>[^\)]+?)\) REFERENCES +"
|
||||
r"(?P<table>%(iq)s[^%(fq)s]+%(fq)s"
|
||||
r"(?:\.%(iq)s[^%(fq)s]+%(fq)s)?) +"
|
||||
r"\((?P<foreign>(?:%(iq)s[^%(fq)s]+%(fq)s(?: *, *)?)+)\)"
|
||||
r"(?: +(?P<match>MATCH \w+))?"
|
||||
r"(?: +ON DELETE (?P<ondelete>%(on)s))?"
|
||||
r"(?: +ON UPDATE (?P<onupdate>%(on)s))?" % kw
|
||||
)
|
||||
|
||||
# CONSTRAINT `CONSTRAINT_1` CHECK (`x` > 5)'
|
||||
# testing on MariaDB 10.2 shows that the CHECK constraint
|
||||
# is returned on a line by itself, so to match without worrying
|
||||
# about parenthesis in the expression we go to the end of the line
|
||||
self._re_ck_constraint = _re_compile(
|
||||
r" "
|
||||
r"CONSTRAINT +"
|
||||
r"%(iq)s(?P<name>(?:%(esc_fq)s|[^%(fq)s])+)%(fq)s +"
|
||||
r"CHECK +"
|
||||
r"\((?P<sqltext>.+)\),?" % kw
|
||||
)
|
||||
|
||||
# PARTITION
|
||||
#
|
||||
# punt!
|
||||
self._re_partition = _re_compile(r"(?:.*)(?:SUB)?PARTITION(?:.*)")
|
||||
|
||||
# Table-level options (COLLATE, ENGINE, etc.)
|
||||
# Do the string options first, since they have quoted
|
||||
# strings we need to get rid of.
|
||||
for option in _options_of_type_string:
|
||||
self._add_option_string(option)
|
||||
|
||||
for option in (
|
||||
"ENGINE",
|
||||
"TYPE",
|
||||
"AUTO_INCREMENT",
|
||||
"AVG_ROW_LENGTH",
|
||||
"CHARACTER SET",
|
||||
"DEFAULT CHARSET",
|
||||
"CHECKSUM",
|
||||
"COLLATE",
|
||||
"DELAY_KEY_WRITE",
|
||||
"INSERT_METHOD",
|
||||
"MAX_ROWS",
|
||||
"MIN_ROWS",
|
||||
"PACK_KEYS",
|
||||
"ROW_FORMAT",
|
||||
"KEY_BLOCK_SIZE",
|
||||
"STATS_SAMPLE_PAGES",
|
||||
):
|
||||
self._add_option_word(option)
|
||||
|
||||
for option in (
|
||||
"PARTITION BY",
|
||||
"SUBPARTITION BY",
|
||||
"PARTITIONS",
|
||||
"SUBPARTITIONS",
|
||||
"PARTITION",
|
||||
"SUBPARTITION",
|
||||
):
|
||||
self._add_partition_option_word(option)
|
||||
|
||||
self._add_option_regex("UNION", r"\([^\)]+\)")
|
||||
self._add_option_regex("TABLESPACE", r".*? STORAGE DISK")
|
||||
self._add_option_regex(
|
||||
"RAID_TYPE",
|
||||
r"\w+\s+RAID_CHUNKS\s*\=\s*\w+RAID_CHUNKSIZE\s*=\s*\w+",
|
||||
)
|
||||
|
||||
_optional_equals = r"(?:\s*(?:=\s*)|\s+)"
|
||||
|
||||
def _add_option_string(self, directive):
|
||||
regex = r"(?P<directive>%s)%s" r"'(?P<val>(?:[^']|'')*?)'(?!')" % (
|
||||
re.escape(directive),
|
||||
self._optional_equals,
|
||||
)
|
||||
self._pr_options.append(_pr_compile(regex, cleanup_text))
|
||||
|
||||
def _add_option_word(self, directive):
|
||||
regex = r"(?P<directive>%s)%s" r"(?P<val>\w+)" % (
|
||||
re.escape(directive),
|
||||
self._optional_equals,
|
||||
)
|
||||
self._pr_options.append(_pr_compile(regex))
|
||||
|
||||
def _add_partition_option_word(self, directive):
|
||||
if directive == "PARTITION BY" or directive == "SUBPARTITION BY":
|
||||
regex = r"(?<!\S)(?P<directive>%s)%s" r"(?P<val>\w+.*)" % (
|
||||
re.escape(directive),
|
||||
self._optional_equals,
|
||||
)
|
||||
elif directive == "SUBPARTITIONS" or directive == "PARTITIONS":
|
||||
regex = r"(?<!\S)(?P<directive>%s)%s" r"(?P<val>\d+)" % (
|
||||
re.escape(directive),
|
||||
self._optional_equals,
|
||||
)
|
||||
else:
|
||||
regex = r"(?<!\S)(?P<directive>%s)(?!\S)" % (re.escape(directive),)
|
||||
self._pr_options.append(_pr_compile(regex))
|
||||
|
||||
def _add_option_regex(self, directive, regex):
|
||||
regex = r"(?P<directive>%s)%s" r"(?P<val>%s)" % (
|
||||
re.escape(directive),
|
||||
self._optional_equals,
|
||||
regex,
|
||||
)
|
||||
self._pr_options.append(_pr_compile(regex))
|
||||
|
||||
|
||||
_options_of_type_string = (
|
||||
"COMMENT",
|
||||
"DATA DIRECTORY",
|
||||
"INDEX DIRECTORY",
|
||||
"PASSWORD",
|
||||
"CONNECTION",
|
||||
)
|
||||
|
||||
|
||||
def _pr_compile(regex, cleanup=None):
|
||||
"""Prepare a 2-tuple of compiled regex and callable."""
|
||||
|
||||
return (_re_compile(regex), cleanup)
|
||||
|
||||
|
||||
def _re_compile(regex):
|
||||
"""Compile a string to regex, I and UNICODE."""
|
||||
|
||||
return re.compile(regex, re.I | re.UNICODE)
|
||||
|
||||
|
||||
def _strip_values(values):
|
||||
"Strip reflected values quotes"
|
||||
strip_values = []
|
||||
for a in values:
|
||||
if a[0:1] == '"' or a[0:1] == "'":
|
||||
# strip enclosing quotes and unquote interior
|
||||
a = a[1:-1].replace(a[0] * 2, a[0])
|
||||
strip_values.append(a)
|
||||
return strip_values
|
||||
|
||||
|
||||
def cleanup_text(raw_text: str) -> str:
|
||||
if "\\" in raw_text:
|
||||
raw_text = re.sub(
|
||||
_control_char_regexp, lambda s: _control_char_map[s[0]], raw_text
|
||||
)
|
||||
return raw_text.replace("''", "'")
|
||||
|
||||
|
||||
_control_char_map = {
|
||||
"\\\\": "\\",
|
||||
"\\0": "\0",
|
||||
"\\a": "\a",
|
||||
"\\b": "\b",
|
||||
"\\t": "\t",
|
||||
"\\n": "\n",
|
||||
"\\v": "\v",
|
||||
"\\f": "\f",
|
||||
"\\r": "\r",
|
||||
# '\\e':'\e',
|
||||
}
|
||||
_control_char_regexp = re.compile(
|
||||
"|".join(re.escape(k) for k in _control_char_map)
|
||||
)
|
|
@ -0,0 +1,567 @@
|
|||
# dialects/mysql/reserved_words.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
|
||||
# generated using:
|
||||
# https://gist.github.com/kkirsche/4f31f2153ed7a3248be1ec44ca6ddbc9
|
||||
#
|
||||
# https://mariadb.com/kb/en/reserved-words/
|
||||
# includes: Reserved Words, Oracle Mode (separate set unioned)
|
||||
# excludes: Exceptions, Function Names
|
||||
# mypy: ignore-errors
|
||||
|
||||
RESERVED_WORDS_MARIADB = {
|
||||
"accessible",
|
||||
"add",
|
||||
"all",
|
||||
"alter",
|
||||
"analyze",
|
||||
"and",
|
||||
"as",
|
||||
"asc",
|
||||
"asensitive",
|
||||
"before",
|
||||
"between",
|
||||
"bigint",
|
||||
"binary",
|
||||
"blob",
|
||||
"both",
|
||||
"by",
|
||||
"call",
|
||||
"cascade",
|
||||
"case",
|
||||
"change",
|
||||
"char",
|
||||
"character",
|
||||
"check",
|
||||
"collate",
|
||||
"column",
|
||||
"condition",
|
||||
"constraint",
|
||||
"continue",
|
||||
"convert",
|
||||
"create",
|
||||
"cross",
|
||||
"current_date",
|
||||
"current_role",
|
||||
"current_time",
|
||||
"current_timestamp",
|
||||
"current_user",
|
||||
"cursor",
|
||||
"database",
|
||||
"databases",
|
||||
"day_hour",
|
||||
"day_microsecond",
|
||||
"day_minute",
|
||||
"day_second",
|
||||
"dec",
|
||||
"decimal",
|
||||
"declare",
|
||||
"default",
|
||||
"delayed",
|
||||
"delete",
|
||||
"desc",
|
||||
"describe",
|
||||
"deterministic",
|
||||
"distinct",
|
||||
"distinctrow",
|
||||
"div",
|
||||
"do_domain_ids",
|
||||
"double",
|
||||
"drop",
|
||||
"dual",
|
||||
"each",
|
||||
"else",
|
||||
"elseif",
|
||||
"enclosed",
|
||||
"escaped",
|
||||
"except",
|
||||
"exists",
|
||||
"exit",
|
||||
"explain",
|
||||
"false",
|
||||
"fetch",
|
||||
"float",
|
||||
"float4",
|
||||
"float8",
|
||||
"for",
|
||||
"force",
|
||||
"foreign",
|
||||
"from",
|
||||
"fulltext",
|
||||
"general",
|
||||
"grant",
|
||||
"group",
|
||||
"having",
|
||||
"high_priority",
|
||||
"hour_microsecond",
|
||||
"hour_minute",
|
||||
"hour_second",
|
||||
"if",
|
||||
"ignore",
|
||||
"ignore_domain_ids",
|
||||
"ignore_server_ids",
|
||||
"in",
|
||||
"index",
|
||||
"infile",
|
||||
"inner",
|
||||
"inout",
|
||||
"insensitive",
|
||||
"insert",
|
||||
"int",
|
||||
"int1",
|
||||
"int2",
|
||||
"int3",
|
||||
"int4",
|
||||
"int8",
|
||||
"integer",
|
||||
"intersect",
|
||||
"interval",
|
||||
"into",
|
||||
"is",
|
||||
"iterate",
|
||||
"join",
|
||||
"key",
|
||||
"keys",
|
||||
"kill",
|
||||
"leading",
|
||||
"leave",
|
||||
"left",
|
||||
"like",
|
||||
"limit",
|
||||
"linear",
|
||||
"lines",
|
||||
"load",
|
||||
"localtime",
|
||||
"localtimestamp",
|
||||
"lock",
|
||||
"long",
|
||||
"longblob",
|
||||
"longtext",
|
||||
"loop",
|
||||
"low_priority",
|
||||
"master_heartbeat_period",
|
||||
"master_ssl_verify_server_cert",
|
||||
"match",
|
||||
"maxvalue",
|
||||
"mediumblob",
|
||||
"mediumint",
|
||||
"mediumtext",
|
||||
"middleint",
|
||||
"minute_microsecond",
|
||||
"minute_second",
|
||||
"mod",
|
||||
"modifies",
|
||||
"natural",
|
||||
"no_write_to_binlog",
|
||||
"not",
|
||||
"null",
|
||||
"numeric",
|
||||
"offset",
|
||||
"on",
|
||||
"optimize",
|
||||
"option",
|
||||
"optionally",
|
||||
"or",
|
||||
"order",
|
||||
"out",
|
||||
"outer",
|
||||
"outfile",
|
||||
"over",
|
||||
"page_checksum",
|
||||
"parse_vcol_expr",
|
||||
"partition",
|
||||
"position",
|
||||
"precision",
|
||||
"primary",
|
||||
"procedure",
|
||||
"purge",
|
||||
"range",
|
||||
"read",
|
||||
"read_write",
|
||||
"reads",
|
||||
"real",
|
||||
"recursive",
|
||||
"ref_system_id",
|
||||
"references",
|
||||
"regexp",
|
||||
"release",
|
||||
"rename",
|
||||
"repeat",
|
||||
"replace",
|
||||
"require",
|
||||
"resignal",
|
||||
"restrict",
|
||||
"return",
|
||||
"returning",
|
||||
"revoke",
|
||||
"right",
|
||||
"rlike",
|
||||
"rows",
|
||||
"row_number",
|
||||
"schema",
|
||||
"schemas",
|
||||
"second_microsecond",
|
||||
"select",
|
||||
"sensitive",
|
||||
"separator",
|
||||
"set",
|
||||
"show",
|
||||
"signal",
|
||||
"slow",
|
||||
"smallint",
|
||||
"spatial",
|
||||
"specific",
|
||||
"sql",
|
||||
"sql_big_result",
|
||||
"sql_calc_found_rows",
|
||||
"sql_small_result",
|
||||
"sqlexception",
|
||||
"sqlstate",
|
||||
"sqlwarning",
|
||||
"ssl",
|
||||
"starting",
|
||||
"stats_auto_recalc",
|
||||
"stats_persistent",
|
||||
"stats_sample_pages",
|
||||
"straight_join",
|
||||
"table",
|
||||
"terminated",
|
||||
"then",
|
||||
"tinyblob",
|
||||
"tinyint",
|
||||
"tinytext",
|
||||
"to",
|
||||
"trailing",
|
||||
"trigger",
|
||||
"true",
|
||||
"undo",
|
||||
"union",
|
||||
"unique",
|
||||
"unlock",
|
||||
"unsigned",
|
||||
"update",
|
||||
"usage",
|
||||
"use",
|
||||
"using",
|
||||
"utc_date",
|
||||
"utc_time",
|
||||
"utc_timestamp",
|
||||
"values",
|
||||
"varbinary",
|
||||
"varchar",
|
||||
"varcharacter",
|
||||
"varying",
|
||||
"when",
|
||||
"where",
|
||||
"while",
|
||||
"window",
|
||||
"with",
|
||||
"write",
|
||||
"xor",
|
||||
"year_month",
|
||||
"zerofill",
|
||||
}.union(
|
||||
{
|
||||
"body",
|
||||
"elsif",
|
||||
"goto",
|
||||
"history",
|
||||
"others",
|
||||
"package",
|
||||
"period",
|
||||
"raise",
|
||||
"rowtype",
|
||||
"system",
|
||||
"system_time",
|
||||
"versioning",
|
||||
"without",
|
||||
}
|
||||
)
|
||||
|
||||
# https://dev.mysql.com/doc/refman/8.0/en/keywords.html
|
||||
# https://dev.mysql.com/doc/refman/5.7/en/keywords.html
|
||||
# https://dev.mysql.com/doc/refman/5.6/en/keywords.html
|
||||
# includes: MySQL x.0 Keywords and Reserved Words
|
||||
# excludes: MySQL x.0 New Keywords and Reserved Words,
|
||||
# MySQL x.0 Removed Keywords and Reserved Words
|
||||
RESERVED_WORDS_MYSQL = {
|
||||
"accessible",
|
||||
"add",
|
||||
"admin",
|
||||
"all",
|
||||
"alter",
|
||||
"analyze",
|
||||
"and",
|
||||
"array",
|
||||
"as",
|
||||
"asc",
|
||||
"asensitive",
|
||||
"before",
|
||||
"between",
|
||||
"bigint",
|
||||
"binary",
|
||||
"blob",
|
||||
"both",
|
||||
"by",
|
||||
"call",
|
||||
"cascade",
|
||||
"case",
|
||||
"change",
|
||||
"char",
|
||||
"character",
|
||||
"check",
|
||||
"collate",
|
||||
"column",
|
||||
"condition",
|
||||
"constraint",
|
||||
"continue",
|
||||
"convert",
|
||||
"create",
|
||||
"cross",
|
||||
"cube",
|
||||
"cume_dist",
|
||||
"current_date",
|
||||
"current_time",
|
||||
"current_timestamp",
|
||||
"current_user",
|
||||
"cursor",
|
||||
"database",
|
||||
"databases",
|
||||
"day_hour",
|
||||
"day_microsecond",
|
||||
"day_minute",
|
||||
"day_second",
|
||||
"dec",
|
||||
"decimal",
|
||||
"declare",
|
||||
"default",
|
||||
"delayed",
|
||||
"delete",
|
||||
"dense_rank",
|
||||
"desc",
|
||||
"describe",
|
||||
"deterministic",
|
||||
"distinct",
|
||||
"distinctrow",
|
||||
"div",
|
||||
"double",
|
||||
"drop",
|
||||
"dual",
|
||||
"each",
|
||||
"else",
|
||||
"elseif",
|
||||
"empty",
|
||||
"enclosed",
|
||||
"escaped",
|
||||
"except",
|
||||
"exists",
|
||||
"exit",
|
||||
"explain",
|
||||
"false",
|
||||
"fetch",
|
||||
"first_value",
|
||||
"float",
|
||||
"float4",
|
||||
"float8",
|
||||
"for",
|
||||
"force",
|
||||
"foreign",
|
||||
"from",
|
||||
"fulltext",
|
||||
"function",
|
||||
"general",
|
||||
"generated",
|
||||
"get",
|
||||
"get_master_public_key",
|
||||
"grant",
|
||||
"group",
|
||||
"grouping",
|
||||
"groups",
|
||||
"having",
|
||||
"high_priority",
|
||||
"hour_microsecond",
|
||||
"hour_minute",
|
||||
"hour_second",
|
||||
"if",
|
||||
"ignore",
|
||||
"ignore_server_ids",
|
||||
"in",
|
||||
"index",
|
||||
"infile",
|
||||
"inner",
|
||||
"inout",
|
||||
"insensitive",
|
||||
"insert",
|
||||
"int",
|
||||
"int1",
|
||||
"int2",
|
||||
"int3",
|
||||
"int4",
|
||||
"int8",
|
||||
"integer",
|
||||
"interval",
|
||||
"into",
|
||||
"io_after_gtids",
|
||||
"io_before_gtids",
|
||||
"is",
|
||||
"iterate",
|
||||
"join",
|
||||
"json_table",
|
||||
"key",
|
||||
"keys",
|
||||
"kill",
|
||||
"lag",
|
||||
"last_value",
|
||||
"lateral",
|
||||
"lead",
|
||||
"leading",
|
||||
"leave",
|
||||
"left",
|
||||
"like",
|
||||
"limit",
|
||||
"linear",
|
||||
"lines",
|
||||
"load",
|
||||
"localtime",
|
||||
"localtimestamp",
|
||||
"lock",
|
||||
"long",
|
||||
"longblob",
|
||||
"longtext",
|
||||
"loop",
|
||||
"low_priority",
|
||||
"master_bind",
|
||||
"master_heartbeat_period",
|
||||
"master_ssl_verify_server_cert",
|
||||
"match",
|
||||
"maxvalue",
|
||||
"mediumblob",
|
||||
"mediumint",
|
||||
"mediumtext",
|
||||
"member",
|
||||
"middleint",
|
||||
"minute_microsecond",
|
||||
"minute_second",
|
||||
"mod",
|
||||
"modifies",
|
||||
"natural",
|
||||
"no_write_to_binlog",
|
||||
"not",
|
||||
"nth_value",
|
||||
"ntile",
|
||||
"null",
|
||||
"numeric",
|
||||
"of",
|
||||
"on",
|
||||
"optimize",
|
||||
"optimizer_costs",
|
||||
"option",
|
||||
"optionally",
|
||||
"or",
|
||||
"order",
|
||||
"out",
|
||||
"outer",
|
||||
"outfile",
|
||||
"over",
|
||||
"parse_gcol_expr",
|
||||
"partition",
|
||||
"percent_rank",
|
||||
"persist",
|
||||
"persist_only",
|
||||
"precision",
|
||||
"primary",
|
||||
"procedure",
|
||||
"purge",
|
||||
"range",
|
||||
"rank",
|
||||
"read",
|
||||
"read_write",
|
||||
"reads",
|
||||
"real",
|
||||
"recursive",
|
||||
"references",
|
||||
"regexp",
|
||||
"release",
|
||||
"rename",
|
||||
"repeat",
|
||||
"replace",
|
||||
"require",
|
||||
"resignal",
|
||||
"restrict",
|
||||
"return",
|
||||
"revoke",
|
||||
"right",
|
||||
"rlike",
|
||||
"role",
|
||||
"row",
|
||||
"row_number",
|
||||
"rows",
|
||||
"schema",
|
||||
"schemas",
|
||||
"second_microsecond",
|
||||
"select",
|
||||
"sensitive",
|
||||
"separator",
|
||||
"set",
|
||||
"show",
|
||||
"signal",
|
||||
"slow",
|
||||
"smallint",
|
||||
"spatial",
|
||||
"specific",
|
||||
"sql",
|
||||
"sql_after_gtids",
|
||||
"sql_before_gtids",
|
||||
"sql_big_result",
|
||||
"sql_calc_found_rows",
|
||||
"sql_small_result",
|
||||
"sqlexception",
|
||||
"sqlstate",
|
||||
"sqlwarning",
|
||||
"ssl",
|
||||
"starting",
|
||||
"stored",
|
||||
"straight_join",
|
||||
"system",
|
||||
"table",
|
||||
"terminated",
|
||||
"then",
|
||||
"tinyblob",
|
||||
"tinyint",
|
||||
"tinytext",
|
||||
"to",
|
||||
"trailing",
|
||||
"trigger",
|
||||
"true",
|
||||
"undo",
|
||||
"union",
|
||||
"unique",
|
||||
"unlock",
|
||||
"unsigned",
|
||||
"update",
|
||||
"usage",
|
||||
"use",
|
||||
"using",
|
||||
"utc_date",
|
||||
"utc_time",
|
||||
"utc_timestamp",
|
||||
"values",
|
||||
"varbinary",
|
||||
"varchar",
|
||||
"varcharacter",
|
||||
"varying",
|
||||
"virtual",
|
||||
"when",
|
||||
"where",
|
||||
"while",
|
||||
"window",
|
||||
"with",
|
||||
"write",
|
||||
"xor",
|
||||
"year_month",
|
||||
"zerofill",
|
||||
}
|
|
@ -0,0 +1,773 @@
|
|||
# dialects/mysql/types.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
import datetime
|
||||
|
||||
from ... import exc
|
||||
from ... import util
|
||||
from ...sql import sqltypes
|
||||
|
||||
|
||||
class _NumericType:
|
||||
"""Base for MySQL numeric types.
|
||||
|
||||
This is the base both for NUMERIC as well as INTEGER, hence
|
||||
it's a mixin.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, unsigned=False, zerofill=False, **kw):
|
||||
self.unsigned = unsigned
|
||||
self.zerofill = zerofill
|
||||
super().__init__(**kw)
|
||||
|
||||
def __repr__(self):
|
||||
return util.generic_repr(
|
||||
self, to_inspect=[_NumericType, sqltypes.Numeric]
|
||||
)
|
||||
|
||||
|
||||
class _FloatType(_NumericType, sqltypes.Float):
|
||||
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
|
||||
if isinstance(self, (REAL, DOUBLE)) and (
|
||||
(precision is None and scale is not None)
|
||||
or (precision is not None and scale is None)
|
||||
):
|
||||
raise exc.ArgumentError(
|
||||
"You must specify both precision and scale or omit "
|
||||
"both altogether."
|
||||
)
|
||||
super().__init__(precision=precision, asdecimal=asdecimal, **kw)
|
||||
self.scale = scale
|
||||
|
||||
def __repr__(self):
|
||||
return util.generic_repr(
|
||||
self, to_inspect=[_FloatType, _NumericType, sqltypes.Float]
|
||||
)
|
||||
|
||||
|
||||
class _IntegerType(_NumericType, sqltypes.Integer):
|
||||
def __init__(self, display_width=None, **kw):
|
||||
self.display_width = display_width
|
||||
super().__init__(**kw)
|
||||
|
||||
def __repr__(self):
|
||||
return util.generic_repr(
|
||||
self, to_inspect=[_IntegerType, _NumericType, sqltypes.Integer]
|
||||
)
|
||||
|
||||
|
||||
class _StringType(sqltypes.String):
|
||||
"""Base for MySQL string types."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
charset=None,
|
||||
collation=None,
|
||||
ascii=False, # noqa
|
||||
binary=False,
|
||||
unicode=False,
|
||||
national=False,
|
||||
**kw,
|
||||
):
|
||||
self.charset = charset
|
||||
|
||||
# allow collate= or collation=
|
||||
kw.setdefault("collation", kw.pop("collate", collation))
|
||||
|
||||
self.ascii = ascii
|
||||
self.unicode = unicode
|
||||
self.binary = binary
|
||||
self.national = national
|
||||
super().__init__(**kw)
|
||||
|
||||
def __repr__(self):
|
||||
return util.generic_repr(
|
||||
self, to_inspect=[_StringType, sqltypes.String]
|
||||
)
|
||||
|
||||
|
||||
class _MatchType(sqltypes.Float, sqltypes.MatchType):
|
||||
def __init__(self, **kw):
|
||||
# TODO: float arguments?
|
||||
sqltypes.Float.__init__(self)
|
||||
sqltypes.MatchType.__init__(self)
|
||||
|
||||
|
||||
class NUMERIC(_NumericType, sqltypes.NUMERIC):
|
||||
"""MySQL NUMERIC type."""
|
||||
|
||||
__visit_name__ = "NUMERIC"
|
||||
|
||||
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
|
||||
"""Construct a NUMERIC.
|
||||
|
||||
:param precision: Total digits in this number. If scale and precision
|
||||
are both None, values are stored to limits allowed by the server.
|
||||
|
||||
:param scale: The number of digits after the decimal point.
|
||||
|
||||
:param unsigned: a boolean, optional.
|
||||
|
||||
:param zerofill: Optional. If true, values will be stored as strings
|
||||
left-padded with zeros. Note that this does not effect the values
|
||||
returned by the underlying database API, which continue to be
|
||||
numeric.
|
||||
|
||||
"""
|
||||
super().__init__(
|
||||
precision=precision, scale=scale, asdecimal=asdecimal, **kw
|
||||
)
|
||||
|
||||
|
||||
class DECIMAL(_NumericType, sqltypes.DECIMAL):
|
||||
"""MySQL DECIMAL type."""
|
||||
|
||||
__visit_name__ = "DECIMAL"
|
||||
|
||||
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
|
||||
"""Construct a DECIMAL.
|
||||
|
||||
:param precision: Total digits in this number. If scale and precision
|
||||
are both None, values are stored to limits allowed by the server.
|
||||
|
||||
:param scale: The number of digits after the decimal point.
|
||||
|
||||
:param unsigned: a boolean, optional.
|
||||
|
||||
:param zerofill: Optional. If true, values will be stored as strings
|
||||
left-padded with zeros. Note that this does not effect the values
|
||||
returned by the underlying database API, which continue to be
|
||||
numeric.
|
||||
|
||||
"""
|
||||
super().__init__(
|
||||
precision=precision, scale=scale, asdecimal=asdecimal, **kw
|
||||
)
|
||||
|
||||
|
||||
class DOUBLE(_FloatType, sqltypes.DOUBLE):
|
||||
"""MySQL DOUBLE type."""
|
||||
|
||||
__visit_name__ = "DOUBLE"
|
||||
|
||||
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
|
||||
"""Construct a DOUBLE.
|
||||
|
||||
.. note::
|
||||
|
||||
The :class:`.DOUBLE` type by default converts from float
|
||||
to Decimal, using a truncation that defaults to 10 digits.
|
||||
Specify either ``scale=n`` or ``decimal_return_scale=n`` in order
|
||||
to change this scale, or ``asdecimal=False`` to return values
|
||||
directly as Python floating points.
|
||||
|
||||
:param precision: Total digits in this number. If scale and precision
|
||||
are both None, values are stored to limits allowed by the server.
|
||||
|
||||
:param scale: The number of digits after the decimal point.
|
||||
|
||||
:param unsigned: a boolean, optional.
|
||||
|
||||
:param zerofill: Optional. If true, values will be stored as strings
|
||||
left-padded with zeros. Note that this does not effect the values
|
||||
returned by the underlying database API, which continue to be
|
||||
numeric.
|
||||
|
||||
"""
|
||||
super().__init__(
|
||||
precision=precision, scale=scale, asdecimal=asdecimal, **kw
|
||||
)
|
||||
|
||||
|
||||
class REAL(_FloatType, sqltypes.REAL):
|
||||
"""MySQL REAL type."""
|
||||
|
||||
__visit_name__ = "REAL"
|
||||
|
||||
def __init__(self, precision=None, scale=None, asdecimal=True, **kw):
|
||||
"""Construct a REAL.
|
||||
|
||||
.. note::
|
||||
|
||||
The :class:`.REAL` type by default converts from float
|
||||
to Decimal, using a truncation that defaults to 10 digits.
|
||||
Specify either ``scale=n`` or ``decimal_return_scale=n`` in order
|
||||
to change this scale, or ``asdecimal=False`` to return values
|
||||
directly as Python floating points.
|
||||
|
||||
:param precision: Total digits in this number. If scale and precision
|
||||
are both None, values are stored to limits allowed by the server.
|
||||
|
||||
:param scale: The number of digits after the decimal point.
|
||||
|
||||
:param unsigned: a boolean, optional.
|
||||
|
||||
:param zerofill: Optional. If true, values will be stored as strings
|
||||
left-padded with zeros. Note that this does not effect the values
|
||||
returned by the underlying database API, which continue to be
|
||||
numeric.
|
||||
|
||||
"""
|
||||
super().__init__(
|
||||
precision=precision, scale=scale, asdecimal=asdecimal, **kw
|
||||
)
|
||||
|
||||
|
||||
class FLOAT(_FloatType, sqltypes.FLOAT):
|
||||
"""MySQL FLOAT type."""
|
||||
|
||||
__visit_name__ = "FLOAT"
|
||||
|
||||
def __init__(self, precision=None, scale=None, asdecimal=False, **kw):
|
||||
"""Construct a FLOAT.
|
||||
|
||||
:param precision: Total digits in this number. If scale and precision
|
||||
are both None, values are stored to limits allowed by the server.
|
||||
|
||||
:param scale: The number of digits after the decimal point.
|
||||
|
||||
:param unsigned: a boolean, optional.
|
||||
|
||||
:param zerofill: Optional. If true, values will be stored as strings
|
||||
left-padded with zeros. Note that this does not effect the values
|
||||
returned by the underlying database API, which continue to be
|
||||
numeric.
|
||||
|
||||
"""
|
||||
super().__init__(
|
||||
precision=precision, scale=scale, asdecimal=asdecimal, **kw
|
||||
)
|
||||
|
||||
def bind_processor(self, dialect):
|
||||
return None
|
||||
|
||||
|
||||
class INTEGER(_IntegerType, sqltypes.INTEGER):
|
||||
"""MySQL INTEGER type."""
|
||||
|
||||
__visit_name__ = "INTEGER"
|
||||
|
||||
def __init__(self, display_width=None, **kw):
|
||||
"""Construct an INTEGER.
|
||||
|
||||
:param display_width: Optional, maximum display width for this number.
|
||||
|
||||
:param unsigned: a boolean, optional.
|
||||
|
||||
:param zerofill: Optional. If true, values will be stored as strings
|
||||
left-padded with zeros. Note that this does not effect the values
|
||||
returned by the underlying database API, which continue to be
|
||||
numeric.
|
||||
|
||||
"""
|
||||
super().__init__(display_width=display_width, **kw)
|
||||
|
||||
|
||||
class BIGINT(_IntegerType, sqltypes.BIGINT):
|
||||
"""MySQL BIGINTEGER type."""
|
||||
|
||||
__visit_name__ = "BIGINT"
|
||||
|
||||
def __init__(self, display_width=None, **kw):
|
||||
"""Construct a BIGINTEGER.
|
||||
|
||||
:param display_width: Optional, maximum display width for this number.
|
||||
|
||||
:param unsigned: a boolean, optional.
|
||||
|
||||
:param zerofill: Optional. If true, values will be stored as strings
|
||||
left-padded with zeros. Note that this does not effect the values
|
||||
returned by the underlying database API, which continue to be
|
||||
numeric.
|
||||
|
||||
"""
|
||||
super().__init__(display_width=display_width, **kw)
|
||||
|
||||
|
||||
class MEDIUMINT(_IntegerType):
|
||||
"""MySQL MEDIUMINTEGER type."""
|
||||
|
||||
__visit_name__ = "MEDIUMINT"
|
||||
|
||||
def __init__(self, display_width=None, **kw):
|
||||
"""Construct a MEDIUMINTEGER
|
||||
|
||||
:param display_width: Optional, maximum display width for this number.
|
||||
|
||||
:param unsigned: a boolean, optional.
|
||||
|
||||
:param zerofill: Optional. If true, values will be stored as strings
|
||||
left-padded with zeros. Note that this does not effect the values
|
||||
returned by the underlying database API, which continue to be
|
||||
numeric.
|
||||
|
||||
"""
|
||||
super().__init__(display_width=display_width, **kw)
|
||||
|
||||
|
||||
class TINYINT(_IntegerType):
|
||||
"""MySQL TINYINT type."""
|
||||
|
||||
__visit_name__ = "TINYINT"
|
||||
|
||||
def __init__(self, display_width=None, **kw):
|
||||
"""Construct a TINYINT.
|
||||
|
||||
:param display_width: Optional, maximum display width for this number.
|
||||
|
||||
:param unsigned: a boolean, optional.
|
||||
|
||||
:param zerofill: Optional. If true, values will be stored as strings
|
||||
left-padded with zeros. Note that this does not effect the values
|
||||
returned by the underlying database API, which continue to be
|
||||
numeric.
|
||||
|
||||
"""
|
||||
super().__init__(display_width=display_width, **kw)
|
||||
|
||||
|
||||
class SMALLINT(_IntegerType, sqltypes.SMALLINT):
|
||||
"""MySQL SMALLINTEGER type."""
|
||||
|
||||
__visit_name__ = "SMALLINT"
|
||||
|
||||
def __init__(self, display_width=None, **kw):
|
||||
"""Construct a SMALLINTEGER.
|
||||
|
||||
:param display_width: Optional, maximum display width for this number.
|
||||
|
||||
:param unsigned: a boolean, optional.
|
||||
|
||||
:param zerofill: Optional. If true, values will be stored as strings
|
||||
left-padded with zeros. Note that this does not effect the values
|
||||
returned by the underlying database API, which continue to be
|
||||
numeric.
|
||||
|
||||
"""
|
||||
super().__init__(display_width=display_width, **kw)
|
||||
|
||||
|
||||
class BIT(sqltypes.TypeEngine):
|
||||
"""MySQL BIT type.
|
||||
|
||||
This type is for MySQL 5.0.3 or greater for MyISAM, and 5.0.5 or greater
|
||||
for MyISAM, MEMORY, InnoDB and BDB. For older versions, use a
|
||||
MSTinyInteger() type.
|
||||
|
||||
"""
|
||||
|
||||
__visit_name__ = "BIT"
|
||||
|
||||
def __init__(self, length=None):
|
||||
"""Construct a BIT.
|
||||
|
||||
:param length: Optional, number of bits.
|
||||
|
||||
"""
|
||||
self.length = length
|
||||
|
||||
def result_processor(self, dialect, coltype):
|
||||
"""Convert a MySQL's 64 bit, variable length binary string to a long.
|
||||
|
||||
TODO: this is MySQL-db, pyodbc specific. OurSQL and mysqlconnector
|
||||
already do this, so this logic should be moved to those dialects.
|
||||
|
||||
"""
|
||||
|
||||
def process(value):
|
||||
if value is not None:
|
||||
v = 0
|
||||
for i in value:
|
||||
if not isinstance(i, int):
|
||||
i = ord(i) # convert byte to int on Python 2
|
||||
v = v << 8 | i
|
||||
return v
|
||||
return value
|
||||
|
||||
return process
|
||||
|
||||
|
||||
class TIME(sqltypes.TIME):
|
||||
"""MySQL TIME type."""
|
||||
|
||||
__visit_name__ = "TIME"
|
||||
|
||||
def __init__(self, timezone=False, fsp=None):
|
||||
"""Construct a MySQL TIME type.
|
||||
|
||||
:param timezone: not used by the MySQL dialect.
|
||||
:param fsp: fractional seconds precision value.
|
||||
MySQL 5.6 supports storage of fractional seconds;
|
||||
this parameter will be used when emitting DDL
|
||||
for the TIME type.
|
||||
|
||||
.. note::
|
||||
|
||||
DBAPI driver support for fractional seconds may
|
||||
be limited; current support includes
|
||||
MySQL Connector/Python.
|
||||
|
||||
"""
|
||||
super().__init__(timezone=timezone)
|
||||
self.fsp = fsp
|
||||
|
||||
def result_processor(self, dialect, coltype):
|
||||
time = datetime.time
|
||||
|
||||
def process(value):
|
||||
# convert from a timedelta value
|
||||
if value is not None:
|
||||
microseconds = value.microseconds
|
||||
seconds = value.seconds
|
||||
minutes = seconds // 60
|
||||
return time(
|
||||
minutes // 60,
|
||||
minutes % 60,
|
||||
seconds - minutes * 60,
|
||||
microsecond=microseconds,
|
||||
)
|
||||
else:
|
||||
return None
|
||||
|
||||
return process
|
||||
|
||||
|
||||
class TIMESTAMP(sqltypes.TIMESTAMP):
|
||||
"""MySQL TIMESTAMP type."""
|
||||
|
||||
__visit_name__ = "TIMESTAMP"
|
||||
|
||||
def __init__(self, timezone=False, fsp=None):
|
||||
"""Construct a MySQL TIMESTAMP type.
|
||||
|
||||
:param timezone: not used by the MySQL dialect.
|
||||
:param fsp: fractional seconds precision value.
|
||||
MySQL 5.6.4 supports storage of fractional seconds;
|
||||
this parameter will be used when emitting DDL
|
||||
for the TIMESTAMP type.
|
||||
|
||||
.. note::
|
||||
|
||||
DBAPI driver support for fractional seconds may
|
||||
be limited; current support includes
|
||||
MySQL Connector/Python.
|
||||
|
||||
"""
|
||||
super().__init__(timezone=timezone)
|
||||
self.fsp = fsp
|
||||
|
||||
|
||||
class DATETIME(sqltypes.DATETIME):
|
||||
"""MySQL DATETIME type."""
|
||||
|
||||
__visit_name__ = "DATETIME"
|
||||
|
||||
def __init__(self, timezone=False, fsp=None):
|
||||
"""Construct a MySQL DATETIME type.
|
||||
|
||||
:param timezone: not used by the MySQL dialect.
|
||||
:param fsp: fractional seconds precision value.
|
||||
MySQL 5.6.4 supports storage of fractional seconds;
|
||||
this parameter will be used when emitting DDL
|
||||
for the DATETIME type.
|
||||
|
||||
.. note::
|
||||
|
||||
DBAPI driver support for fractional seconds may
|
||||
be limited; current support includes
|
||||
MySQL Connector/Python.
|
||||
|
||||
"""
|
||||
super().__init__(timezone=timezone)
|
||||
self.fsp = fsp
|
||||
|
||||
|
||||
class YEAR(sqltypes.TypeEngine):
|
||||
"""MySQL YEAR type, for single byte storage of years 1901-2155."""
|
||||
|
||||
__visit_name__ = "YEAR"
|
||||
|
||||
def __init__(self, display_width=None):
|
||||
self.display_width = display_width
|
||||
|
||||
|
||||
class TEXT(_StringType, sqltypes.TEXT):
|
||||
"""MySQL TEXT type, for text up to 2^16 characters."""
|
||||
|
||||
__visit_name__ = "TEXT"
|
||||
|
||||
def __init__(self, length=None, **kw):
|
||||
"""Construct a TEXT.
|
||||
|
||||
:param length: Optional, if provided the server may optimize storage
|
||||
by substituting the smallest TEXT type sufficient to store
|
||||
``length`` characters.
|
||||
|
||||
:param charset: Optional, a column-level character set for this string
|
||||
value. Takes precedence to 'ascii' or 'unicode' short-hand.
|
||||
|
||||
:param collation: Optional, a column-level collation for this string
|
||||
value. Takes precedence to 'binary' short-hand.
|
||||
|
||||
:param ascii: Defaults to False: short-hand for the ``latin1``
|
||||
character set, generates ASCII in schema.
|
||||
|
||||
:param unicode: Defaults to False: short-hand for the ``ucs2``
|
||||
character set, generates UNICODE in schema.
|
||||
|
||||
:param national: Optional. If true, use the server's configured
|
||||
national character set.
|
||||
|
||||
:param binary: Defaults to False: short-hand, pick the binary
|
||||
collation type that matches the column's character set. Generates
|
||||
BINARY in schema. This does not affect the type of data stored,
|
||||
only the collation of character data.
|
||||
|
||||
"""
|
||||
super().__init__(length=length, **kw)
|
||||
|
||||
|
||||
class TINYTEXT(_StringType):
|
||||
"""MySQL TINYTEXT type, for text up to 2^8 characters."""
|
||||
|
||||
__visit_name__ = "TINYTEXT"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Construct a TINYTEXT.
|
||||
|
||||
:param charset: Optional, a column-level character set for this string
|
||||
value. Takes precedence to 'ascii' or 'unicode' short-hand.
|
||||
|
||||
:param collation: Optional, a column-level collation for this string
|
||||
value. Takes precedence to 'binary' short-hand.
|
||||
|
||||
:param ascii: Defaults to False: short-hand for the ``latin1``
|
||||
character set, generates ASCII in schema.
|
||||
|
||||
:param unicode: Defaults to False: short-hand for the ``ucs2``
|
||||
character set, generates UNICODE in schema.
|
||||
|
||||
:param national: Optional. If true, use the server's configured
|
||||
national character set.
|
||||
|
||||
:param binary: Defaults to False: short-hand, pick the binary
|
||||
collation type that matches the column's character set. Generates
|
||||
BINARY in schema. This does not affect the type of data stored,
|
||||
only the collation of character data.
|
||||
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class MEDIUMTEXT(_StringType):
|
||||
"""MySQL MEDIUMTEXT type, for text up to 2^24 characters."""
|
||||
|
||||
__visit_name__ = "MEDIUMTEXT"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Construct a MEDIUMTEXT.
|
||||
|
||||
:param charset: Optional, a column-level character set for this string
|
||||
value. Takes precedence to 'ascii' or 'unicode' short-hand.
|
||||
|
||||
:param collation: Optional, a column-level collation for this string
|
||||
value. Takes precedence to 'binary' short-hand.
|
||||
|
||||
:param ascii: Defaults to False: short-hand for the ``latin1``
|
||||
character set, generates ASCII in schema.
|
||||
|
||||
:param unicode: Defaults to False: short-hand for the ``ucs2``
|
||||
character set, generates UNICODE in schema.
|
||||
|
||||
:param national: Optional. If true, use the server's configured
|
||||
national character set.
|
||||
|
||||
:param binary: Defaults to False: short-hand, pick the binary
|
||||
collation type that matches the column's character set. Generates
|
||||
BINARY in schema. This does not affect the type of data stored,
|
||||
only the collation of character data.
|
||||
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class LONGTEXT(_StringType):
|
||||
"""MySQL LONGTEXT type, for text up to 2^32 characters."""
|
||||
|
||||
__visit_name__ = "LONGTEXT"
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
"""Construct a LONGTEXT.
|
||||
|
||||
:param charset: Optional, a column-level character set for this string
|
||||
value. Takes precedence to 'ascii' or 'unicode' short-hand.
|
||||
|
||||
:param collation: Optional, a column-level collation for this string
|
||||
value. Takes precedence to 'binary' short-hand.
|
||||
|
||||
:param ascii: Defaults to False: short-hand for the ``latin1``
|
||||
character set, generates ASCII in schema.
|
||||
|
||||
:param unicode: Defaults to False: short-hand for the ``ucs2``
|
||||
character set, generates UNICODE in schema.
|
||||
|
||||
:param national: Optional. If true, use the server's configured
|
||||
national character set.
|
||||
|
||||
:param binary: Defaults to False: short-hand, pick the binary
|
||||
collation type that matches the column's character set. Generates
|
||||
BINARY in schema. This does not affect the type of data stored,
|
||||
only the collation of character data.
|
||||
|
||||
"""
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class VARCHAR(_StringType, sqltypes.VARCHAR):
|
||||
"""MySQL VARCHAR type, for variable-length character data."""
|
||||
|
||||
__visit_name__ = "VARCHAR"
|
||||
|
||||
def __init__(self, length=None, **kwargs):
|
||||
"""Construct a VARCHAR.
|
||||
|
||||
:param charset: Optional, a column-level character set for this string
|
||||
value. Takes precedence to 'ascii' or 'unicode' short-hand.
|
||||
|
||||
:param collation: Optional, a column-level collation for this string
|
||||
value. Takes precedence to 'binary' short-hand.
|
||||
|
||||
:param ascii: Defaults to False: short-hand for the ``latin1``
|
||||
character set, generates ASCII in schema.
|
||||
|
||||
:param unicode: Defaults to False: short-hand for the ``ucs2``
|
||||
character set, generates UNICODE in schema.
|
||||
|
||||
:param national: Optional. If true, use the server's configured
|
||||
national character set.
|
||||
|
||||
:param binary: Defaults to False: short-hand, pick the binary
|
||||
collation type that matches the column's character set. Generates
|
||||
BINARY in schema. This does not affect the type of data stored,
|
||||
only the collation of character data.
|
||||
|
||||
"""
|
||||
super().__init__(length=length, **kwargs)
|
||||
|
||||
|
||||
class CHAR(_StringType, sqltypes.CHAR):
|
||||
"""MySQL CHAR type, for fixed-length character data."""
|
||||
|
||||
__visit_name__ = "CHAR"
|
||||
|
||||
def __init__(self, length=None, **kwargs):
|
||||
"""Construct a CHAR.
|
||||
|
||||
:param length: Maximum data length, in characters.
|
||||
|
||||
:param binary: Optional, use the default binary collation for the
|
||||
national character set. This does not affect the type of data
|
||||
stored, use a BINARY type for binary data.
|
||||
|
||||
:param collation: Optional, request a particular collation. Must be
|
||||
compatible with the national character set.
|
||||
|
||||
"""
|
||||
super().__init__(length=length, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def _adapt_string_for_cast(self, type_):
|
||||
# copy the given string type into a CHAR
|
||||
# for the purposes of rendering a CAST expression
|
||||
type_ = sqltypes.to_instance(type_)
|
||||
if isinstance(type_, sqltypes.CHAR):
|
||||
return type_
|
||||
elif isinstance(type_, _StringType):
|
||||
return CHAR(
|
||||
length=type_.length,
|
||||
charset=type_.charset,
|
||||
collation=type_.collation,
|
||||
ascii=type_.ascii,
|
||||
binary=type_.binary,
|
||||
unicode=type_.unicode,
|
||||
national=False, # not supported in CAST
|
||||
)
|
||||
else:
|
||||
return CHAR(length=type_.length)
|
||||
|
||||
|
||||
class NVARCHAR(_StringType, sqltypes.NVARCHAR):
|
||||
"""MySQL NVARCHAR type.
|
||||
|
||||
For variable-length character data in the server's configured national
|
||||
character set.
|
||||
"""
|
||||
|
||||
__visit_name__ = "NVARCHAR"
|
||||
|
||||
def __init__(self, length=None, **kwargs):
|
||||
"""Construct an NVARCHAR.
|
||||
|
||||
:param length: Maximum data length, in characters.
|
||||
|
||||
:param binary: Optional, use the default binary collation for the
|
||||
national character set. This does not affect the type of data
|
||||
stored, use a BINARY type for binary data.
|
||||
|
||||
:param collation: Optional, request a particular collation. Must be
|
||||
compatible with the national character set.
|
||||
|
||||
"""
|
||||
kwargs["national"] = True
|
||||
super().__init__(length=length, **kwargs)
|
||||
|
||||
|
||||
class NCHAR(_StringType, sqltypes.NCHAR):
|
||||
"""MySQL NCHAR type.
|
||||
|
||||
For fixed-length character data in the server's configured national
|
||||
character set.
|
||||
"""
|
||||
|
||||
__visit_name__ = "NCHAR"
|
||||
|
||||
def __init__(self, length=None, **kwargs):
|
||||
"""Construct an NCHAR.
|
||||
|
||||
:param length: Maximum data length, in characters.
|
||||
|
||||
:param binary: Optional, use the default binary collation for the
|
||||
national character set. This does not affect the type of data
|
||||
stored, use a BINARY type for binary data.
|
||||
|
||||
:param collation: Optional, request a particular collation. Must be
|
||||
compatible with the national character set.
|
||||
|
||||
"""
|
||||
kwargs["national"] = True
|
||||
super().__init__(length=length, **kwargs)
|
||||
|
||||
|
||||
class TINYBLOB(sqltypes._Binary):
|
||||
"""MySQL TINYBLOB type, for binary data up to 2^8 bytes."""
|
||||
|
||||
__visit_name__ = "TINYBLOB"
|
||||
|
||||
|
||||
class MEDIUMBLOB(sqltypes._Binary):
|
||||
"""MySQL MEDIUMBLOB type, for binary data up to 2^24 bytes."""
|
||||
|
||||
__visit_name__ = "MEDIUMBLOB"
|
||||
|
||||
|
||||
class LONGBLOB(sqltypes._Binary):
|
||||
"""MySQL LONGBLOB type, for binary data up to 2^32 bytes."""
|
||||
|
||||
__visit_name__ = "LONGBLOB"
|
|
@ -0,0 +1,67 @@
|
|||
# dialects/oracle/__init__.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
from types import ModuleType
|
||||
|
||||
from . import base # noqa
|
||||
from . import cx_oracle # noqa
|
||||
from . import oracledb # noqa
|
||||
from .base import BFILE
|
||||
from .base import BINARY_DOUBLE
|
||||
from .base import BINARY_FLOAT
|
||||
from .base import BLOB
|
||||
from .base import CHAR
|
||||
from .base import CLOB
|
||||
from .base import DATE
|
||||
from .base import DOUBLE_PRECISION
|
||||
from .base import FLOAT
|
||||
from .base import INTERVAL
|
||||
from .base import LONG
|
||||
from .base import NCHAR
|
||||
from .base import NCLOB
|
||||
from .base import NUMBER
|
||||
from .base import NVARCHAR
|
||||
from .base import NVARCHAR2
|
||||
from .base import RAW
|
||||
from .base import REAL
|
||||
from .base import ROWID
|
||||
from .base import TIMESTAMP
|
||||
from .base import VARCHAR
|
||||
from .base import VARCHAR2
|
||||
|
||||
# Alias oracledb also as oracledb_async
|
||||
oracledb_async = type(
|
||||
"oracledb_async", (ModuleType,), {"dialect": oracledb.dialect_async}
|
||||
)
|
||||
|
||||
base.dialect = dialect = cx_oracle.dialect
|
||||
|
||||
__all__ = (
|
||||
"VARCHAR",
|
||||
"NVARCHAR",
|
||||
"CHAR",
|
||||
"NCHAR",
|
||||
"DATE",
|
||||
"NUMBER",
|
||||
"BLOB",
|
||||
"BFILE",
|
||||
"CLOB",
|
||||
"NCLOB",
|
||||
"TIMESTAMP",
|
||||
"RAW",
|
||||
"FLOAT",
|
||||
"DOUBLE_PRECISION",
|
||||
"BINARY_DOUBLE",
|
||||
"BINARY_FLOAT",
|
||||
"LONG",
|
||||
"dialect",
|
||||
"INTERVAL",
|
||||
"VARCHAR2",
|
||||
"NVARCHAR2",
|
||||
"ROWID",
|
||||
"REAL",
|
||||
)
|
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
|
@ -0,0 +1,507 @@
|
|||
# dialects/oracle/dictionary.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
from .types import DATE
|
||||
from .types import LONG
|
||||
from .types import NUMBER
|
||||
from .types import RAW
|
||||
from .types import VARCHAR2
|
||||
from ... import Column
|
||||
from ... import MetaData
|
||||
from ... import Table
|
||||
from ... import table
|
||||
from ...sql.sqltypes import CHAR
|
||||
|
||||
# constants
|
||||
DB_LINK_PLACEHOLDER = "__$sa_dblink$__"
|
||||
# tables
|
||||
dual = table("dual")
|
||||
dictionary_meta = MetaData()
|
||||
|
||||
# NOTE: all the dictionary_meta are aliases because oracle does not like
|
||||
# using the full table@dblink for every column in query, and complains with
|
||||
# ORA-00960: ambiguous column naming in select list
|
||||
all_tables = Table(
|
||||
"all_tables" + DB_LINK_PLACEHOLDER,
|
||||
dictionary_meta,
|
||||
Column("owner", VARCHAR2(128), nullable=False),
|
||||
Column("table_name", VARCHAR2(128), nullable=False),
|
||||
Column("tablespace_name", VARCHAR2(30)),
|
||||
Column("cluster_name", VARCHAR2(128)),
|
||||
Column("iot_name", VARCHAR2(128)),
|
||||
Column("status", VARCHAR2(8)),
|
||||
Column("pct_free", NUMBER),
|
||||
Column("pct_used", NUMBER),
|
||||
Column("ini_trans", NUMBER),
|
||||
Column("max_trans", NUMBER),
|
||||
Column("initial_extent", NUMBER),
|
||||
Column("next_extent", NUMBER),
|
||||
Column("min_extents", NUMBER),
|
||||
Column("max_extents", NUMBER),
|
||||
Column("pct_increase", NUMBER),
|
||||
Column("freelists", NUMBER),
|
||||
Column("freelist_groups", NUMBER),
|
||||
Column("logging", VARCHAR2(3)),
|
||||
Column("backed_up", VARCHAR2(1)),
|
||||
Column("num_rows", NUMBER),
|
||||
Column("blocks", NUMBER),
|
||||
Column("empty_blocks", NUMBER),
|
||||
Column("avg_space", NUMBER),
|
||||
Column("chain_cnt", NUMBER),
|
||||
Column("avg_row_len", NUMBER),
|
||||
Column("avg_space_freelist_blocks", NUMBER),
|
||||
Column("num_freelist_blocks", NUMBER),
|
||||
Column("degree", VARCHAR2(10)),
|
||||
Column("instances", VARCHAR2(10)),
|
||||
Column("cache", VARCHAR2(5)),
|
||||
Column("table_lock", VARCHAR2(8)),
|
||||
Column("sample_size", NUMBER),
|
||||
Column("last_analyzed", DATE),
|
||||
Column("partitioned", VARCHAR2(3)),
|
||||
Column("iot_type", VARCHAR2(12)),
|
||||
Column("temporary", VARCHAR2(1)),
|
||||
Column("secondary", VARCHAR2(1)),
|
||||
Column("nested", VARCHAR2(3)),
|
||||
Column("buffer_pool", VARCHAR2(7)),
|
||||
Column("flash_cache", VARCHAR2(7)),
|
||||
Column("cell_flash_cache", VARCHAR2(7)),
|
||||
Column("row_movement", VARCHAR2(8)),
|
||||
Column("global_stats", VARCHAR2(3)),
|
||||
Column("user_stats", VARCHAR2(3)),
|
||||
Column("duration", VARCHAR2(15)),
|
||||
Column("skip_corrupt", VARCHAR2(8)),
|
||||
Column("monitoring", VARCHAR2(3)),
|
||||
Column("cluster_owner", VARCHAR2(128)),
|
||||
Column("dependencies", VARCHAR2(8)),
|
||||
Column("compression", VARCHAR2(8)),
|
||||
Column("compress_for", VARCHAR2(30)),
|
||||
Column("dropped", VARCHAR2(3)),
|
||||
Column("read_only", VARCHAR2(3)),
|
||||
Column("segment_created", VARCHAR2(3)),
|
||||
Column("result_cache", VARCHAR2(7)),
|
||||
Column("clustering", VARCHAR2(3)),
|
||||
Column("activity_tracking", VARCHAR2(23)),
|
||||
Column("dml_timestamp", VARCHAR2(25)),
|
||||
Column("has_identity", VARCHAR2(3)),
|
||||
Column("container_data", VARCHAR2(3)),
|
||||
Column("inmemory", VARCHAR2(8)),
|
||||
Column("inmemory_priority", VARCHAR2(8)),
|
||||
Column("inmemory_distribute", VARCHAR2(15)),
|
||||
Column("inmemory_compression", VARCHAR2(17)),
|
||||
Column("inmemory_duplicate", VARCHAR2(13)),
|
||||
Column("default_collation", VARCHAR2(100)),
|
||||
Column("duplicated", VARCHAR2(1)),
|
||||
Column("sharded", VARCHAR2(1)),
|
||||
Column("externally_sharded", VARCHAR2(1)),
|
||||
Column("externally_duplicated", VARCHAR2(1)),
|
||||
Column("external", VARCHAR2(3)),
|
||||
Column("hybrid", VARCHAR2(3)),
|
||||
Column("cellmemory", VARCHAR2(24)),
|
||||
Column("containers_default", VARCHAR2(3)),
|
||||
Column("container_map", VARCHAR2(3)),
|
||||
Column("extended_data_link", VARCHAR2(3)),
|
||||
Column("extended_data_link_map", VARCHAR2(3)),
|
||||
Column("inmemory_service", VARCHAR2(12)),
|
||||
Column("inmemory_service_name", VARCHAR2(1000)),
|
||||
Column("container_map_object", VARCHAR2(3)),
|
||||
Column("memoptimize_read", VARCHAR2(8)),
|
||||
Column("memoptimize_write", VARCHAR2(8)),
|
||||
Column("has_sensitive_column", VARCHAR2(3)),
|
||||
Column("admit_null", VARCHAR2(3)),
|
||||
Column("data_link_dml_enabled", VARCHAR2(3)),
|
||||
Column("logical_replication", VARCHAR2(8)),
|
||||
).alias("a_tables")
|
||||
|
||||
all_views = Table(
|
||||
"all_views" + DB_LINK_PLACEHOLDER,
|
||||
dictionary_meta,
|
||||
Column("owner", VARCHAR2(128), nullable=False),
|
||||
Column("view_name", VARCHAR2(128), nullable=False),
|
||||
Column("text_length", NUMBER),
|
||||
Column("text", LONG),
|
||||
Column("text_vc", VARCHAR2(4000)),
|
||||
Column("type_text_length", NUMBER),
|
||||
Column("type_text", VARCHAR2(4000)),
|
||||
Column("oid_text_length", NUMBER),
|
||||
Column("oid_text", VARCHAR2(4000)),
|
||||
Column("view_type_owner", VARCHAR2(128)),
|
||||
Column("view_type", VARCHAR2(128)),
|
||||
Column("superview_name", VARCHAR2(128)),
|
||||
Column("editioning_view", VARCHAR2(1)),
|
||||
Column("read_only", VARCHAR2(1)),
|
||||
Column("container_data", VARCHAR2(1)),
|
||||
Column("bequeath", VARCHAR2(12)),
|
||||
Column("origin_con_id", VARCHAR2(256)),
|
||||
Column("default_collation", VARCHAR2(100)),
|
||||
Column("containers_default", VARCHAR2(3)),
|
||||
Column("container_map", VARCHAR2(3)),
|
||||
Column("extended_data_link", VARCHAR2(3)),
|
||||
Column("extended_data_link_map", VARCHAR2(3)),
|
||||
Column("has_sensitive_column", VARCHAR2(3)),
|
||||
Column("admit_null", VARCHAR2(3)),
|
||||
Column("pdb_local_only", VARCHAR2(3)),
|
||||
).alias("a_views")
|
||||
|
||||
all_sequences = Table(
|
||||
"all_sequences" + DB_LINK_PLACEHOLDER,
|
||||
dictionary_meta,
|
||||
Column("sequence_owner", VARCHAR2(128), nullable=False),
|
||||
Column("sequence_name", VARCHAR2(128), nullable=False),
|
||||
Column("min_value", NUMBER),
|
||||
Column("max_value", NUMBER),
|
||||
Column("increment_by", NUMBER, nullable=False),
|
||||
Column("cycle_flag", VARCHAR2(1)),
|
||||
Column("order_flag", VARCHAR2(1)),
|
||||
Column("cache_size", NUMBER, nullable=False),
|
||||
Column("last_number", NUMBER, nullable=False),
|
||||
Column("scale_flag", VARCHAR2(1)),
|
||||
Column("extend_flag", VARCHAR2(1)),
|
||||
Column("sharded_flag", VARCHAR2(1)),
|
||||
Column("session_flag", VARCHAR2(1)),
|
||||
Column("keep_value", VARCHAR2(1)),
|
||||
).alias("a_sequences")
|
||||
|
||||
all_users = Table(
|
||||
"all_users" + DB_LINK_PLACEHOLDER,
|
||||
dictionary_meta,
|
||||
Column("username", VARCHAR2(128), nullable=False),
|
||||
Column("user_id", NUMBER, nullable=False),
|
||||
Column("created", DATE, nullable=False),
|
||||
Column("common", VARCHAR2(3)),
|
||||
Column("oracle_maintained", VARCHAR2(1)),
|
||||
Column("inherited", VARCHAR2(3)),
|
||||
Column("default_collation", VARCHAR2(100)),
|
||||
Column("implicit", VARCHAR2(3)),
|
||||
Column("all_shard", VARCHAR2(3)),
|
||||
Column("external_shard", VARCHAR2(3)),
|
||||
).alias("a_users")
|
||||
|
||||
all_mviews = Table(
|
||||
"all_mviews" + DB_LINK_PLACEHOLDER,
|
||||
dictionary_meta,
|
||||
Column("owner", VARCHAR2(128), nullable=False),
|
||||
Column("mview_name", VARCHAR2(128), nullable=False),
|
||||
Column("container_name", VARCHAR2(128), nullable=False),
|
||||
Column("query", LONG),
|
||||
Column("query_len", NUMBER(38)),
|
||||
Column("updatable", VARCHAR2(1)),
|
||||
Column("update_log", VARCHAR2(128)),
|
||||
Column("master_rollback_seg", VARCHAR2(128)),
|
||||
Column("master_link", VARCHAR2(128)),
|
||||
Column("rewrite_enabled", VARCHAR2(1)),
|
||||
Column("rewrite_capability", VARCHAR2(9)),
|
||||
Column("refresh_mode", VARCHAR2(6)),
|
||||
Column("refresh_method", VARCHAR2(8)),
|
||||
Column("build_mode", VARCHAR2(9)),
|
||||
Column("fast_refreshable", VARCHAR2(18)),
|
||||
Column("last_refresh_type", VARCHAR2(8)),
|
||||
Column("last_refresh_date", DATE),
|
||||
Column("last_refresh_end_time", DATE),
|
||||
Column("staleness", VARCHAR2(19)),
|
||||
Column("after_fast_refresh", VARCHAR2(19)),
|
||||
Column("unknown_prebuilt", VARCHAR2(1)),
|
||||
Column("unknown_plsql_func", VARCHAR2(1)),
|
||||
Column("unknown_external_table", VARCHAR2(1)),
|
||||
Column("unknown_consider_fresh", VARCHAR2(1)),
|
||||
Column("unknown_import", VARCHAR2(1)),
|
||||
Column("unknown_trusted_fd", VARCHAR2(1)),
|
||||
Column("compile_state", VARCHAR2(19)),
|
||||
Column("use_no_index", VARCHAR2(1)),
|
||||
Column("stale_since", DATE),
|
||||
Column("num_pct_tables", NUMBER),
|
||||
Column("num_fresh_pct_regions", NUMBER),
|
||||
Column("num_stale_pct_regions", NUMBER),
|
||||
Column("segment_created", VARCHAR2(3)),
|
||||
Column("evaluation_edition", VARCHAR2(128)),
|
||||
Column("unusable_before", VARCHAR2(128)),
|
||||
Column("unusable_beginning", VARCHAR2(128)),
|
||||
Column("default_collation", VARCHAR2(100)),
|
||||
Column("on_query_computation", VARCHAR2(1)),
|
||||
Column("auto", VARCHAR2(3)),
|
||||
).alias("a_mviews")
|
||||
|
||||
all_tab_identity_cols = Table(
|
||||
"all_tab_identity_cols" + DB_LINK_PLACEHOLDER,
|
||||
dictionary_meta,
|
||||
Column("owner", VARCHAR2(128), nullable=False),
|
||||
Column("table_name", VARCHAR2(128), nullable=False),
|
||||
Column("column_name", VARCHAR2(128), nullable=False),
|
||||
Column("generation_type", VARCHAR2(10)),
|
||||
Column("sequence_name", VARCHAR2(128), nullable=False),
|
||||
Column("identity_options", VARCHAR2(298)),
|
||||
).alias("a_tab_identity_cols")
|
||||
|
||||
all_tab_cols = Table(
|
||||
"all_tab_cols" + DB_LINK_PLACEHOLDER,
|
||||
dictionary_meta,
|
||||
Column("owner", VARCHAR2(128), nullable=False),
|
||||
Column("table_name", VARCHAR2(128), nullable=False),
|
||||
Column("column_name", VARCHAR2(128), nullable=False),
|
||||
Column("data_type", VARCHAR2(128)),
|
||||
Column("data_type_mod", VARCHAR2(3)),
|
||||
Column("data_type_owner", VARCHAR2(128)),
|
||||
Column("data_length", NUMBER, nullable=False),
|
||||
Column("data_precision", NUMBER),
|
||||
Column("data_scale", NUMBER),
|
||||
Column("nullable", VARCHAR2(1)),
|
||||
Column("column_id", NUMBER),
|
||||
Column("default_length", NUMBER),
|
||||
Column("data_default", LONG),
|
||||
Column("num_distinct", NUMBER),
|
||||
Column("low_value", RAW(1000)),
|
||||
Column("high_value", RAW(1000)),
|
||||
Column("density", NUMBER),
|
||||
Column("num_nulls", NUMBER),
|
||||
Column("num_buckets", NUMBER),
|
||||
Column("last_analyzed", DATE),
|
||||
Column("sample_size", NUMBER),
|
||||
Column("character_set_name", VARCHAR2(44)),
|
||||
Column("char_col_decl_length", NUMBER),
|
||||
Column("global_stats", VARCHAR2(3)),
|
||||
Column("user_stats", VARCHAR2(3)),
|
||||
Column("avg_col_len", NUMBER),
|
||||
Column("char_length", NUMBER),
|
||||
Column("char_used", VARCHAR2(1)),
|
||||
Column("v80_fmt_image", VARCHAR2(3)),
|
||||
Column("data_upgraded", VARCHAR2(3)),
|
||||
Column("hidden_column", VARCHAR2(3)),
|
||||
Column("virtual_column", VARCHAR2(3)),
|
||||
Column("segment_column_id", NUMBER),
|
||||
Column("internal_column_id", NUMBER, nullable=False),
|
||||
Column("histogram", VARCHAR2(15)),
|
||||
Column("qualified_col_name", VARCHAR2(4000)),
|
||||
Column("user_generated", VARCHAR2(3)),
|
||||
Column("default_on_null", VARCHAR2(3)),
|
||||
Column("identity_column", VARCHAR2(3)),
|
||||
Column("evaluation_edition", VARCHAR2(128)),
|
||||
Column("unusable_before", VARCHAR2(128)),
|
||||
Column("unusable_beginning", VARCHAR2(128)),
|
||||
Column("collation", VARCHAR2(100)),
|
||||
Column("collated_column_id", NUMBER),
|
||||
).alias("a_tab_cols")
|
||||
|
||||
all_tab_comments = Table(
|
||||
"all_tab_comments" + DB_LINK_PLACEHOLDER,
|
||||
dictionary_meta,
|
||||
Column("owner", VARCHAR2(128), nullable=False),
|
||||
Column("table_name", VARCHAR2(128), nullable=False),
|
||||
Column("table_type", VARCHAR2(11)),
|
||||
Column("comments", VARCHAR2(4000)),
|
||||
Column("origin_con_id", NUMBER),
|
||||
).alias("a_tab_comments")
|
||||
|
||||
all_col_comments = Table(
|
||||
"all_col_comments" + DB_LINK_PLACEHOLDER,
|
||||
dictionary_meta,
|
||||
Column("owner", VARCHAR2(128), nullable=False),
|
||||
Column("table_name", VARCHAR2(128), nullable=False),
|
||||
Column("column_name", VARCHAR2(128), nullable=False),
|
||||
Column("comments", VARCHAR2(4000)),
|
||||
Column("origin_con_id", NUMBER),
|
||||
).alias("a_col_comments")
|
||||
|
||||
all_mview_comments = Table(
|
||||
"all_mview_comments" + DB_LINK_PLACEHOLDER,
|
||||
dictionary_meta,
|
||||
Column("owner", VARCHAR2(128), nullable=False),
|
||||
Column("mview_name", VARCHAR2(128), nullable=False),
|
||||
Column("comments", VARCHAR2(4000)),
|
||||
).alias("a_mview_comments")
|
||||
|
||||
all_ind_columns = Table(
|
||||
"all_ind_columns" + DB_LINK_PLACEHOLDER,
|
||||
dictionary_meta,
|
||||
Column("index_owner", VARCHAR2(128), nullable=False),
|
||||
Column("index_name", VARCHAR2(128), nullable=False),
|
||||
Column("table_owner", VARCHAR2(128), nullable=False),
|
||||
Column("table_name", VARCHAR2(128), nullable=False),
|
||||
Column("column_name", VARCHAR2(4000)),
|
||||
Column("column_position", NUMBER, nullable=False),
|
||||
Column("column_length", NUMBER, nullable=False),
|
||||
Column("char_length", NUMBER),
|
||||
Column("descend", VARCHAR2(4)),
|
||||
Column("collated_column_id", NUMBER),
|
||||
).alias("a_ind_columns")
|
||||
|
||||
all_indexes = Table(
|
||||
"all_indexes" + DB_LINK_PLACEHOLDER,
|
||||
dictionary_meta,
|
||||
Column("owner", VARCHAR2(128), nullable=False),
|
||||
Column("index_name", VARCHAR2(128), nullable=False),
|
||||
Column("index_type", VARCHAR2(27)),
|
||||
Column("table_owner", VARCHAR2(128), nullable=False),
|
||||
Column("table_name", VARCHAR2(128), nullable=False),
|
||||
Column("table_type", CHAR(11)),
|
||||
Column("uniqueness", VARCHAR2(9)),
|
||||
Column("compression", VARCHAR2(13)),
|
||||
Column("prefix_length", NUMBER),
|
||||
Column("tablespace_name", VARCHAR2(30)),
|
||||
Column("ini_trans", NUMBER),
|
||||
Column("max_trans", NUMBER),
|
||||
Column("initial_extent", NUMBER),
|
||||
Column("next_extent", NUMBER),
|
||||
Column("min_extents", NUMBER),
|
||||
Column("max_extents", NUMBER),
|
||||
Column("pct_increase", NUMBER),
|
||||
Column("pct_threshold", NUMBER),
|
||||
Column("include_column", NUMBER),
|
||||
Column("freelists", NUMBER),
|
||||
Column("freelist_groups", NUMBER),
|
||||
Column("pct_free", NUMBER),
|
||||
Column("logging", VARCHAR2(3)),
|
||||
Column("blevel", NUMBER),
|
||||
Column("leaf_blocks", NUMBER),
|
||||
Column("distinct_keys", NUMBER),
|
||||
Column("avg_leaf_blocks_per_key", NUMBER),
|
||||
Column("avg_data_blocks_per_key", NUMBER),
|
||||
Column("clustering_factor", NUMBER),
|
||||
Column("status", VARCHAR2(8)),
|
||||
Column("num_rows", NUMBER),
|
||||
Column("sample_size", NUMBER),
|
||||
Column("last_analyzed", DATE),
|
||||
Column("degree", VARCHAR2(40)),
|
||||
Column("instances", VARCHAR2(40)),
|
||||
Column("partitioned", VARCHAR2(3)),
|
||||
Column("temporary", VARCHAR2(1)),
|
||||
Column("generated", VARCHAR2(1)),
|
||||
Column("secondary", VARCHAR2(1)),
|
||||
Column("buffer_pool", VARCHAR2(7)),
|
||||
Column("flash_cache", VARCHAR2(7)),
|
||||
Column("cell_flash_cache", VARCHAR2(7)),
|
||||
Column("user_stats", VARCHAR2(3)),
|
||||
Column("duration", VARCHAR2(15)),
|
||||
Column("pct_direct_access", NUMBER),
|
||||
Column("ityp_owner", VARCHAR2(128)),
|
||||
Column("ityp_name", VARCHAR2(128)),
|
||||
Column("parameters", VARCHAR2(1000)),
|
||||
Column("global_stats", VARCHAR2(3)),
|
||||
Column("domidx_status", VARCHAR2(12)),
|
||||
Column("domidx_opstatus", VARCHAR2(6)),
|
||||
Column("funcidx_status", VARCHAR2(8)),
|
||||
Column("join_index", VARCHAR2(3)),
|
||||
Column("iot_redundant_pkey_elim", VARCHAR2(3)),
|
||||
Column("dropped", VARCHAR2(3)),
|
||||
Column("visibility", VARCHAR2(9)),
|
||||
Column("domidx_management", VARCHAR2(14)),
|
||||
Column("segment_created", VARCHAR2(3)),
|
||||
Column("orphaned_entries", VARCHAR2(3)),
|
||||
Column("indexing", VARCHAR2(7)),
|
||||
Column("auto", VARCHAR2(3)),
|
||||
).alias("a_indexes")
|
||||
|
||||
all_ind_expressions = Table(
|
||||
"all_ind_expressions" + DB_LINK_PLACEHOLDER,
|
||||
dictionary_meta,
|
||||
Column("index_owner", VARCHAR2(128), nullable=False),
|
||||
Column("index_name", VARCHAR2(128), nullable=False),
|
||||
Column("table_owner", VARCHAR2(128), nullable=False),
|
||||
Column("table_name", VARCHAR2(128), nullable=False),
|
||||
Column("column_expression", LONG),
|
||||
Column("column_position", NUMBER, nullable=False),
|
||||
).alias("a_ind_expressions")
|
||||
|
||||
all_constraints = Table(
|
||||
"all_constraints" + DB_LINK_PLACEHOLDER,
|
||||
dictionary_meta,
|
||||
Column("owner", VARCHAR2(128)),
|
||||
Column("constraint_name", VARCHAR2(128)),
|
||||
Column("constraint_type", VARCHAR2(1)),
|
||||
Column("table_name", VARCHAR2(128)),
|
||||
Column("search_condition", LONG),
|
||||
Column("search_condition_vc", VARCHAR2(4000)),
|
||||
Column("r_owner", VARCHAR2(128)),
|
||||
Column("r_constraint_name", VARCHAR2(128)),
|
||||
Column("delete_rule", VARCHAR2(9)),
|
||||
Column("status", VARCHAR2(8)),
|
||||
Column("deferrable", VARCHAR2(14)),
|
||||
Column("deferred", VARCHAR2(9)),
|
||||
Column("validated", VARCHAR2(13)),
|
||||
Column("generated", VARCHAR2(14)),
|
||||
Column("bad", VARCHAR2(3)),
|
||||
Column("rely", VARCHAR2(4)),
|
||||
Column("last_change", DATE),
|
||||
Column("index_owner", VARCHAR2(128)),
|
||||
Column("index_name", VARCHAR2(128)),
|
||||
Column("invalid", VARCHAR2(7)),
|
||||
Column("view_related", VARCHAR2(14)),
|
||||
Column("origin_con_id", VARCHAR2(256)),
|
||||
).alias("a_constraints")
|
||||
|
||||
all_cons_columns = Table(
|
||||
"all_cons_columns" + DB_LINK_PLACEHOLDER,
|
||||
dictionary_meta,
|
||||
Column("owner", VARCHAR2(128), nullable=False),
|
||||
Column("constraint_name", VARCHAR2(128), nullable=False),
|
||||
Column("table_name", VARCHAR2(128), nullable=False),
|
||||
Column("column_name", VARCHAR2(4000)),
|
||||
Column("position", NUMBER),
|
||||
).alias("a_cons_columns")
|
||||
|
||||
# TODO figure out if it's still relevant, since there is no mention from here
|
||||
# https://docs.oracle.com/en/database/oracle/oracle-database/21/refrn/ALL_DB_LINKS.html
|
||||
# original note:
|
||||
# using user_db_links here since all_db_links appears
|
||||
# to have more restricted permissions.
|
||||
# https://docs.oracle.com/cd/B28359_01/server.111/b28310/ds_admin005.htm
|
||||
# will need to hear from more users if we are doing
|
||||
# the right thing here. See [ticket:2619]
|
||||
all_db_links = Table(
|
||||
"all_db_links" + DB_LINK_PLACEHOLDER,
|
||||
dictionary_meta,
|
||||
Column("owner", VARCHAR2(128), nullable=False),
|
||||
Column("db_link", VARCHAR2(128), nullable=False),
|
||||
Column("username", VARCHAR2(128)),
|
||||
Column("host", VARCHAR2(2000)),
|
||||
Column("created", DATE, nullable=False),
|
||||
Column("hidden", VARCHAR2(3)),
|
||||
Column("shard_internal", VARCHAR2(3)),
|
||||
Column("valid", VARCHAR2(3)),
|
||||
Column("intra_cdb", VARCHAR2(3)),
|
||||
).alias("a_db_links")
|
||||
|
||||
all_synonyms = Table(
|
||||
"all_synonyms" + DB_LINK_PLACEHOLDER,
|
||||
dictionary_meta,
|
||||
Column("owner", VARCHAR2(128)),
|
||||
Column("synonym_name", VARCHAR2(128)),
|
||||
Column("table_owner", VARCHAR2(128)),
|
||||
Column("table_name", VARCHAR2(128)),
|
||||
Column("db_link", VARCHAR2(128)),
|
||||
Column("origin_con_id", VARCHAR2(256)),
|
||||
).alias("a_synonyms")
|
||||
|
||||
all_objects = Table(
|
||||
"all_objects" + DB_LINK_PLACEHOLDER,
|
||||
dictionary_meta,
|
||||
Column("owner", VARCHAR2(128), nullable=False),
|
||||
Column("object_name", VARCHAR2(128), nullable=False),
|
||||
Column("subobject_name", VARCHAR2(128)),
|
||||
Column("object_id", NUMBER, nullable=False),
|
||||
Column("data_object_id", NUMBER),
|
||||
Column("object_type", VARCHAR2(23)),
|
||||
Column("created", DATE, nullable=False),
|
||||
Column("last_ddl_time", DATE, nullable=False),
|
||||
Column("timestamp", VARCHAR2(19)),
|
||||
Column("status", VARCHAR2(7)),
|
||||
Column("temporary", VARCHAR2(1)),
|
||||
Column("generated", VARCHAR2(1)),
|
||||
Column("secondary", VARCHAR2(1)),
|
||||
Column("namespace", NUMBER, nullable=False),
|
||||
Column("edition_name", VARCHAR2(128)),
|
||||
Column("sharing", VARCHAR2(13)),
|
||||
Column("editionable", VARCHAR2(1)),
|
||||
Column("oracle_maintained", VARCHAR2(1)),
|
||||
Column("application", VARCHAR2(1)),
|
||||
Column("default_collation", VARCHAR2(100)),
|
||||
Column("duplicated", VARCHAR2(1)),
|
||||
Column("sharded", VARCHAR2(1)),
|
||||
Column("created_appid", NUMBER),
|
||||
Column("created_vsnid", NUMBER),
|
||||
Column("modified_appid", NUMBER),
|
||||
Column("modified_vsnid", NUMBER),
|
||||
).alias("a_objects")
|
|
@ -0,0 +1,311 @@
|
|||
# dialects/oracle/oracledb.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
r"""
|
||||
.. dialect:: oracle+oracledb
|
||||
:name: python-oracledb
|
||||
:dbapi: oracledb
|
||||
:connectstring: oracle+oracledb://user:pass@hostname:port[/dbname][?service_name=<service>[&key=value&key=value...]]
|
||||
:url: https://oracle.github.io/python-oracledb/
|
||||
|
||||
python-oracledb is released by Oracle to supersede the cx_Oracle driver.
|
||||
It is fully compatible with cx_Oracle and features both a "thin" client
|
||||
mode that requires no dependencies, as well as a "thick" mode that uses
|
||||
the Oracle Client Interface in the same way as cx_Oracle.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:ref:`cx_oracle` - all of cx_Oracle's notes apply to the oracledb driver
|
||||
as well.
|
||||
|
||||
The SQLAlchemy ``oracledb`` dialect provides both a sync and an async
|
||||
implementation under the same dialect name. The proper version is
|
||||
selected depending on how the engine is created:
|
||||
|
||||
* calling :func:`_sa.create_engine` with ``oracle+oracledb://...`` will
|
||||
automatically select the sync version, e.g.::
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
sync_engine = create_engine("oracle+oracledb://scott:tiger@localhost/?service_name=XEPDB1")
|
||||
|
||||
* calling :func:`_asyncio.create_async_engine` with
|
||||
``oracle+oracledb://...`` will automatically select the async version,
|
||||
e.g.::
|
||||
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
asyncio_engine = create_async_engine("oracle+oracledb://scott:tiger@localhost/?service_name=XEPDB1")
|
||||
|
||||
The asyncio version of the dialect may also be specified explicitly using the
|
||||
``oracledb_async`` suffix, as::
|
||||
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
asyncio_engine = create_async_engine("oracle+oracledb_async://scott:tiger@localhost/?service_name=XEPDB1")
|
||||
|
||||
.. versionadded:: 2.0.25 added support for the async version of oracledb.
|
||||
|
||||
Thick mode support
|
||||
------------------
|
||||
|
||||
By default the ``python-oracledb`` is started in thin mode, that does not
|
||||
require oracle client libraries to be installed in the system. The
|
||||
``python-oracledb`` driver also support a "thick" mode, that behaves
|
||||
similarly to ``cx_oracle`` and requires that Oracle Client Interface (OCI)
|
||||
is installed.
|
||||
|
||||
To enable this mode, the user may call ``oracledb.init_oracle_client``
|
||||
manually, or by passing the parameter ``thick_mode=True`` to
|
||||
:func:`_sa.create_engine`. To pass custom arguments to ``init_oracle_client``,
|
||||
like the ``lib_dir`` path, a dict may be passed to this parameter, as in::
|
||||
|
||||
engine = sa.create_engine("oracle+oracledb://...", thick_mode={
|
||||
"lib_dir": "/path/to/oracle/client/lib", "driver_name": "my-app"
|
||||
})
|
||||
|
||||
.. seealso::
|
||||
|
||||
https://python-oracledb.readthedocs.io/en/latest/api_manual/module.html#oracledb.init_oracle_client
|
||||
|
||||
|
||||
.. versionadded:: 2.0.0 added support for oracledb driver.
|
||||
|
||||
""" # noqa
|
||||
from __future__ import annotations
|
||||
|
||||
import collections
|
||||
import re
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from .cx_oracle import OracleDialect_cx_oracle as _OracleDialect_cx_oracle
|
||||
from ... import exc
|
||||
from ... import pool
|
||||
from ...connectors.asyncio import AsyncAdapt_dbapi_connection
|
||||
from ...connectors.asyncio import AsyncAdapt_dbapi_cursor
|
||||
from ...connectors.asyncio import AsyncAdaptFallback_dbapi_connection
|
||||
from ...util import asbool
|
||||
from ...util import await_fallback
|
||||
from ...util import await_only
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from oracledb import AsyncConnection
|
||||
from oracledb import AsyncCursor
|
||||
|
||||
|
||||
class OracleDialect_oracledb(_OracleDialect_cx_oracle):
|
||||
supports_statement_cache = True
|
||||
driver = "oracledb"
|
||||
_min_version = (1,)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
auto_convert_lobs=True,
|
||||
coerce_to_decimal=True,
|
||||
arraysize=None,
|
||||
encoding_errors=None,
|
||||
thick_mode=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
auto_convert_lobs,
|
||||
coerce_to_decimal,
|
||||
arraysize,
|
||||
encoding_errors,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if self.dbapi is not None and (
|
||||
thick_mode or isinstance(thick_mode, dict)
|
||||
):
|
||||
kw = thick_mode if isinstance(thick_mode, dict) else {}
|
||||
self.dbapi.init_oracle_client(**kw)
|
||||
|
||||
@classmethod
|
||||
def import_dbapi(cls):
|
||||
import oracledb
|
||||
|
||||
return oracledb
|
||||
|
||||
@classmethod
|
||||
def is_thin_mode(cls, connection):
|
||||
return connection.connection.dbapi_connection.thin
|
||||
|
||||
@classmethod
|
||||
def get_async_dialect_cls(cls, url):
|
||||
return OracleDialectAsync_oracledb
|
||||
|
||||
def _load_version(self, dbapi_module):
|
||||
version = (0, 0, 0)
|
||||
if dbapi_module is not None:
|
||||
m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", dbapi_module.version)
|
||||
if m:
|
||||
version = tuple(
|
||||
int(x) for x in m.group(1, 2, 3) if x is not None
|
||||
)
|
||||
self.oracledb_ver = version
|
||||
if (
|
||||
self.oracledb_ver > (0, 0, 0)
|
||||
and self.oracledb_ver < self._min_version
|
||||
):
|
||||
raise exc.InvalidRequestError(
|
||||
f"oracledb version {self._min_version} and above are supported"
|
||||
)
|
||||
|
||||
|
||||
class AsyncAdapt_oracledb_cursor(AsyncAdapt_dbapi_cursor):
|
||||
_cursor: AsyncCursor
|
||||
__slots__ = ()
|
||||
|
||||
@property
|
||||
def outputtypehandler(self):
|
||||
return self._cursor.outputtypehandler
|
||||
|
||||
@outputtypehandler.setter
|
||||
def outputtypehandler(self, value):
|
||||
self._cursor.outputtypehandler = value
|
||||
|
||||
def var(self, *args, **kwargs):
|
||||
return self._cursor.var(*args, **kwargs)
|
||||
|
||||
def close(self):
|
||||
self._rows.clear()
|
||||
self._cursor.close()
|
||||
|
||||
def setinputsizes(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return self._cursor.setinputsizes(*args, **kwargs)
|
||||
|
||||
def _aenter_cursor(self, cursor: AsyncCursor) -> AsyncCursor:
|
||||
try:
|
||||
return cursor.__enter__()
|
||||
except Exception as error:
|
||||
self._adapt_connection._handle_exception(error)
|
||||
|
||||
async def _execute_async(self, operation, parameters):
|
||||
# override to not use mutex, oracledb already has mutex
|
||||
|
||||
if parameters is None:
|
||||
result = await self._cursor.execute(operation)
|
||||
else:
|
||||
result = await self._cursor.execute(operation, parameters)
|
||||
|
||||
if self._cursor.description and not self.server_side:
|
||||
self._rows = collections.deque(await self._cursor.fetchall())
|
||||
return result
|
||||
|
||||
async def _executemany_async(
|
||||
self,
|
||||
operation,
|
||||
seq_of_parameters,
|
||||
):
|
||||
# override to not use mutex, oracledb already has mutex
|
||||
return await self._cursor.executemany(operation, seq_of_parameters)
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, type_: Any, value: Any, traceback: Any) -> None:
|
||||
self.close()
|
||||
|
||||
|
||||
class AsyncAdapt_oracledb_connection(AsyncAdapt_dbapi_connection):
|
||||
_connection: AsyncConnection
|
||||
__slots__ = ()
|
||||
|
||||
thin = True
|
||||
|
||||
_cursor_cls = AsyncAdapt_oracledb_cursor
|
||||
_ss_cursor_cls = None
|
||||
|
||||
@property
|
||||
def autocommit(self):
|
||||
return self._connection.autocommit
|
||||
|
||||
@autocommit.setter
|
||||
def autocommit(self, value):
|
||||
self._connection.autocommit = value
|
||||
|
||||
@property
|
||||
def outputtypehandler(self):
|
||||
return self._connection.outputtypehandler
|
||||
|
||||
@outputtypehandler.setter
|
||||
def outputtypehandler(self, value):
|
||||
self._connection.outputtypehandler = value
|
||||
|
||||
@property
|
||||
def version(self):
|
||||
return self._connection.version
|
||||
|
||||
@property
|
||||
def stmtcachesize(self):
|
||||
return self._connection.stmtcachesize
|
||||
|
||||
@stmtcachesize.setter
|
||||
def stmtcachesize(self, value):
|
||||
self._connection.stmtcachesize = value
|
||||
|
||||
def cursor(self):
|
||||
return AsyncAdapt_oracledb_cursor(self)
|
||||
|
||||
|
||||
class AsyncAdaptFallback_oracledb_connection(
|
||||
AsyncAdaptFallback_dbapi_connection, AsyncAdapt_oracledb_connection
|
||||
):
|
||||
__slots__ = ()
|
||||
|
||||
|
||||
class OracledbAdaptDBAPI:
|
||||
def __init__(self, oracledb) -> None:
|
||||
self.oracledb = oracledb
|
||||
|
||||
for k, v in self.oracledb.__dict__.items():
|
||||
if k != "connect":
|
||||
self.__dict__[k] = v
|
||||
|
||||
def connect(self, *arg, **kw):
|
||||
async_fallback = kw.pop("async_fallback", False)
|
||||
creator_fn = kw.pop("async_creator_fn", self.oracledb.connect_async)
|
||||
|
||||
if asbool(async_fallback):
|
||||
return AsyncAdaptFallback_oracledb_connection(
|
||||
self, await_fallback(creator_fn(*arg, **kw))
|
||||
)
|
||||
|
||||
else:
|
||||
return AsyncAdapt_oracledb_connection(
|
||||
self, await_only(creator_fn(*arg, **kw))
|
||||
)
|
||||
|
||||
|
||||
class OracleDialectAsync_oracledb(OracleDialect_oracledb):
|
||||
is_async = True
|
||||
supports_statement_cache = True
|
||||
|
||||
_min_version = (2,)
|
||||
|
||||
# thick_mode mode is not supported by asyncio, oracledb will raise
|
||||
@classmethod
|
||||
def import_dbapi(cls):
|
||||
import oracledb
|
||||
|
||||
return OracledbAdaptDBAPI(oracledb)
|
||||
|
||||
@classmethod
|
||||
def get_pool_class(cls, url):
|
||||
async_fallback = url.query.get("async_fallback", False)
|
||||
|
||||
if asbool(async_fallback):
|
||||
return pool.FallbackAsyncAdaptedQueuePool
|
||||
else:
|
||||
return pool.AsyncAdaptedQueuePool
|
||||
|
||||
def get_driver_connection(self, connection):
|
||||
return connection._connection
|
||||
|
||||
|
||||
dialect = OracleDialect_oracledb
|
||||
dialect_async = OracleDialectAsync_oracledb
|
|
@ -0,0 +1,220 @@
|
|||
# dialects/oracle/provision.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
from ... import create_engine
|
||||
from ... import exc
|
||||
from ... import inspect
|
||||
from ...engine import url as sa_url
|
||||
from ...testing.provision import configure_follower
|
||||
from ...testing.provision import create_db
|
||||
from ...testing.provision import drop_all_schema_objects_post_tables
|
||||
from ...testing.provision import drop_all_schema_objects_pre_tables
|
||||
from ...testing.provision import drop_db
|
||||
from ...testing.provision import follower_url_from_main
|
||||
from ...testing.provision import log
|
||||
from ...testing.provision import post_configure_engine
|
||||
from ...testing.provision import run_reap_dbs
|
||||
from ...testing.provision import set_default_schema_on_connection
|
||||
from ...testing.provision import stop_test_class_outside_fixtures
|
||||
from ...testing.provision import temp_table_keyword_args
|
||||
from ...testing.provision import update_db_opts
|
||||
|
||||
|
||||
@create_db.for_db("oracle")
|
||||
def _oracle_create_db(cfg, eng, ident):
|
||||
# NOTE: make sure you've run "ALTER DATABASE default tablespace users" or
|
||||
# similar, so that the default tablespace is not "system"; reflection will
|
||||
# fail otherwise
|
||||
with eng.begin() as conn:
|
||||
conn.exec_driver_sql("create user %s identified by xe" % ident)
|
||||
conn.exec_driver_sql("create user %s_ts1 identified by xe" % ident)
|
||||
conn.exec_driver_sql("create user %s_ts2 identified by xe" % ident)
|
||||
conn.exec_driver_sql("grant dba to %s" % (ident,))
|
||||
conn.exec_driver_sql("grant unlimited tablespace to %s" % ident)
|
||||
conn.exec_driver_sql("grant unlimited tablespace to %s_ts1" % ident)
|
||||
conn.exec_driver_sql("grant unlimited tablespace to %s_ts2" % ident)
|
||||
# these are needed to create materialized views
|
||||
conn.exec_driver_sql("grant create table to %s" % ident)
|
||||
conn.exec_driver_sql("grant create table to %s_ts1" % ident)
|
||||
conn.exec_driver_sql("grant create table to %s_ts2" % ident)
|
||||
|
||||
|
||||
@configure_follower.for_db("oracle")
|
||||
def _oracle_configure_follower(config, ident):
|
||||
config.test_schema = "%s_ts1" % ident
|
||||
config.test_schema_2 = "%s_ts2" % ident
|
||||
|
||||
|
||||
def _ora_drop_ignore(conn, dbname):
|
||||
try:
|
||||
conn.exec_driver_sql("drop user %s cascade" % dbname)
|
||||
log.info("Reaped db: %s", dbname)
|
||||
return True
|
||||
except exc.DatabaseError as err:
|
||||
log.warning("couldn't drop db: %s", err)
|
||||
return False
|
||||
|
||||
|
||||
@drop_all_schema_objects_pre_tables.for_db("oracle")
|
||||
def _ora_drop_all_schema_objects_pre_tables(cfg, eng):
|
||||
_purge_recyclebin(eng)
|
||||
_purge_recyclebin(eng, cfg.test_schema)
|
||||
|
||||
|
||||
@drop_all_schema_objects_post_tables.for_db("oracle")
|
||||
def _ora_drop_all_schema_objects_post_tables(cfg, eng):
|
||||
with eng.begin() as conn:
|
||||
for syn in conn.dialect._get_synonyms(conn, None, None, None):
|
||||
conn.exec_driver_sql(f"drop synonym {syn['synonym_name']}")
|
||||
|
||||
for syn in conn.dialect._get_synonyms(
|
||||
conn, cfg.test_schema, None, None
|
||||
):
|
||||
conn.exec_driver_sql(
|
||||
f"drop synonym {cfg.test_schema}.{syn['synonym_name']}"
|
||||
)
|
||||
|
||||
for tmp_table in inspect(conn).get_temp_table_names():
|
||||
conn.exec_driver_sql(f"drop table {tmp_table}")
|
||||
|
||||
|
||||
@drop_db.for_db("oracle")
|
||||
def _oracle_drop_db(cfg, eng, ident):
|
||||
with eng.begin() as conn:
|
||||
# cx_Oracle seems to occasionally leak open connections when a large
|
||||
# suite it run, even if we confirm we have zero references to
|
||||
# connection objects.
|
||||
# while there is a "kill session" command in Oracle,
|
||||
# it unfortunately does not release the connection sufficiently.
|
||||
_ora_drop_ignore(conn, ident)
|
||||
_ora_drop_ignore(conn, "%s_ts1" % ident)
|
||||
_ora_drop_ignore(conn, "%s_ts2" % ident)
|
||||
|
||||
|
||||
@stop_test_class_outside_fixtures.for_db("oracle")
|
||||
def _ora_stop_test_class_outside_fixtures(config, db, cls):
|
||||
try:
|
||||
_purge_recyclebin(db)
|
||||
except exc.DatabaseError as err:
|
||||
log.warning("purge recyclebin command failed: %s", err)
|
||||
|
||||
# clear statement cache on all connections that were used
|
||||
# https://github.com/oracle/python-cx_Oracle/issues/519
|
||||
|
||||
for cx_oracle_conn in _all_conns:
|
||||
try:
|
||||
sc = cx_oracle_conn.stmtcachesize
|
||||
except db.dialect.dbapi.InterfaceError:
|
||||
# connection closed
|
||||
pass
|
||||
else:
|
||||
cx_oracle_conn.stmtcachesize = 0
|
||||
cx_oracle_conn.stmtcachesize = sc
|
||||
_all_conns.clear()
|
||||
|
||||
|
||||
def _purge_recyclebin(eng, schema=None):
|
||||
with eng.begin() as conn:
|
||||
if schema is None:
|
||||
# run magic command to get rid of identity sequences
|
||||
# https://floo.bar/2019/11/29/drop-the-underlying-sequence-of-an-identity-column/ # noqa: E501
|
||||
conn.exec_driver_sql("purge recyclebin")
|
||||
else:
|
||||
# per user: https://community.oracle.com/tech/developers/discussion/2255402/how-to-clear-dba-recyclebin-for-a-particular-user # noqa: E501
|
||||
for owner, object_name, type_ in conn.exec_driver_sql(
|
||||
"select owner, object_name,type from "
|
||||
"dba_recyclebin where owner=:schema and type='TABLE'",
|
||||
{"schema": conn.dialect.denormalize_name(schema)},
|
||||
).all():
|
||||
conn.exec_driver_sql(f'purge {type_} {owner}."{object_name}"')
|
||||
|
||||
|
||||
_all_conns = set()
|
||||
|
||||
|
||||
@post_configure_engine.for_db("oracle")
|
||||
def _oracle_post_configure_engine(url, engine, follower_ident):
|
||||
from sqlalchemy import event
|
||||
|
||||
@event.listens_for(engine, "checkout")
|
||||
def checkout(dbapi_con, con_record, con_proxy):
|
||||
_all_conns.add(dbapi_con)
|
||||
|
||||
@event.listens_for(engine, "checkin")
|
||||
def checkin(dbapi_connection, connection_record):
|
||||
# work around cx_Oracle issue:
|
||||
# https://github.com/oracle/python-cx_Oracle/issues/530
|
||||
# invalidate oracle connections that had 2pc set up
|
||||
if "cx_oracle_xid" in connection_record.info:
|
||||
connection_record.invalidate()
|
||||
|
||||
|
||||
@run_reap_dbs.for_db("oracle")
|
||||
def _reap_oracle_dbs(url, idents):
|
||||
log.info("db reaper connecting to %r", url)
|
||||
eng = create_engine(url)
|
||||
with eng.begin() as conn:
|
||||
log.info("identifiers in file: %s", ", ".join(idents))
|
||||
|
||||
to_reap = conn.exec_driver_sql(
|
||||
"select u.username from all_users u where username "
|
||||
"like 'TEST_%' and not exists (select username "
|
||||
"from v$session where username=u.username)"
|
||||
)
|
||||
all_names = {username.lower() for (username,) in to_reap}
|
||||
to_drop = set()
|
||||
for name in all_names:
|
||||
if name.endswith("_ts1") or name.endswith("_ts2"):
|
||||
continue
|
||||
elif name in idents:
|
||||
to_drop.add(name)
|
||||
if "%s_ts1" % name in all_names:
|
||||
to_drop.add("%s_ts1" % name)
|
||||
if "%s_ts2" % name in all_names:
|
||||
to_drop.add("%s_ts2" % name)
|
||||
|
||||
dropped = total = 0
|
||||
for total, username in enumerate(to_drop, 1):
|
||||
if _ora_drop_ignore(conn, username):
|
||||
dropped += 1
|
||||
log.info(
|
||||
"Dropped %d out of %d stale databases detected", dropped, total
|
||||
)
|
||||
|
||||
|
||||
@follower_url_from_main.for_db("oracle")
|
||||
def _oracle_follower_url_from_main(url, ident):
|
||||
url = sa_url.make_url(url)
|
||||
return url.set(username=ident, password="xe")
|
||||
|
||||
|
||||
@temp_table_keyword_args.for_db("oracle")
|
||||
def _oracle_temp_table_keyword_args(cfg, eng):
|
||||
return {
|
||||
"prefixes": ["GLOBAL TEMPORARY"],
|
||||
"oracle_on_commit": "PRESERVE ROWS",
|
||||
}
|
||||
|
||||
|
||||
@set_default_schema_on_connection.for_db("oracle")
|
||||
def _oracle_set_default_schema_on_connection(
|
||||
cfg, dbapi_connection, schema_name
|
||||
):
|
||||
cursor = dbapi_connection.cursor()
|
||||
cursor.execute("ALTER SESSION SET CURRENT_SCHEMA=%s" % schema_name)
|
||||
cursor.close()
|
||||
|
||||
|
||||
@update_db_opts.for_db("oracle")
|
||||
def _update_db_opts(db_url, db_opts, options):
|
||||
"""Set database options (db_opts) for a test database that we created."""
|
||||
if (
|
||||
options.oracledb_thick_mode
|
||||
and sa_url.make_url(db_url).get_driver_name() == "oracledb"
|
||||
):
|
||||
db_opts["thick_mode"] = True
|
|
@ -0,0 +1,287 @@
|
|||
# dialects/oracle/types.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime as dt
|
||||
from typing import Optional
|
||||
from typing import Type
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ... import exc
|
||||
from ...sql import sqltypes
|
||||
from ...types import NVARCHAR
|
||||
from ...types import VARCHAR
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...engine.interfaces import Dialect
|
||||
from ...sql.type_api import _LiteralProcessorType
|
||||
|
||||
|
||||
class RAW(sqltypes._Binary):
|
||||
__visit_name__ = "RAW"
|
||||
|
||||
|
||||
OracleRaw = RAW
|
||||
|
||||
|
||||
class NCLOB(sqltypes.Text):
|
||||
__visit_name__ = "NCLOB"
|
||||
|
||||
|
||||
class VARCHAR2(VARCHAR):
|
||||
__visit_name__ = "VARCHAR2"
|
||||
|
||||
|
||||
NVARCHAR2 = NVARCHAR
|
||||
|
||||
|
||||
class NUMBER(sqltypes.Numeric, sqltypes.Integer):
|
||||
__visit_name__ = "NUMBER"
|
||||
|
||||
def __init__(self, precision=None, scale=None, asdecimal=None):
|
||||
if asdecimal is None:
|
||||
asdecimal = bool(scale and scale > 0)
|
||||
|
||||
super().__init__(precision=precision, scale=scale, asdecimal=asdecimal)
|
||||
|
||||
def adapt(self, impltype):
|
||||
ret = super().adapt(impltype)
|
||||
# leave a hint for the DBAPI handler
|
||||
ret._is_oracle_number = True
|
||||
return ret
|
||||
|
||||
@property
|
||||
def _type_affinity(self):
|
||||
if bool(self.scale and self.scale > 0):
|
||||
return sqltypes.Numeric
|
||||
else:
|
||||
return sqltypes.Integer
|
||||
|
||||
|
||||
class FLOAT(sqltypes.FLOAT):
|
||||
"""Oracle FLOAT.
|
||||
|
||||
This is the same as :class:`_sqltypes.FLOAT` except that
|
||||
an Oracle-specific :paramref:`_oracle.FLOAT.binary_precision`
|
||||
parameter is accepted, and
|
||||
the :paramref:`_sqltypes.Float.precision` parameter is not accepted.
|
||||
|
||||
Oracle FLOAT types indicate precision in terms of "binary precision", which
|
||||
defaults to 126. For a REAL type, the value is 63. This parameter does not
|
||||
cleanly map to a specific number of decimal places but is roughly
|
||||
equivalent to the desired number of decimal places divided by 0.3103.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
"""
|
||||
|
||||
__visit_name__ = "FLOAT"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
binary_precision=None,
|
||||
asdecimal=False,
|
||||
decimal_return_scale=None,
|
||||
):
|
||||
r"""
|
||||
Construct a FLOAT
|
||||
|
||||
:param binary_precision: Oracle binary precision value to be rendered
|
||||
in DDL. This may be approximated to the number of decimal characters
|
||||
using the formula "decimal precision = 0.30103 * binary precision".
|
||||
The default value used by Oracle for FLOAT / DOUBLE PRECISION is 126.
|
||||
|
||||
:param asdecimal: See :paramref:`_sqltypes.Float.asdecimal`
|
||||
|
||||
:param decimal_return_scale: See
|
||||
:paramref:`_sqltypes.Float.decimal_return_scale`
|
||||
|
||||
"""
|
||||
super().__init__(
|
||||
asdecimal=asdecimal, decimal_return_scale=decimal_return_scale
|
||||
)
|
||||
self.binary_precision = binary_precision
|
||||
|
||||
|
||||
class BINARY_DOUBLE(sqltypes.Double):
|
||||
__visit_name__ = "BINARY_DOUBLE"
|
||||
|
||||
|
||||
class BINARY_FLOAT(sqltypes.Float):
|
||||
__visit_name__ = "BINARY_FLOAT"
|
||||
|
||||
|
||||
class BFILE(sqltypes.LargeBinary):
|
||||
__visit_name__ = "BFILE"
|
||||
|
||||
|
||||
class LONG(sqltypes.Text):
|
||||
__visit_name__ = "LONG"
|
||||
|
||||
|
||||
class _OracleDateLiteralRender:
|
||||
def _literal_processor_datetime(self, dialect):
|
||||
def process(value):
|
||||
if getattr(value, "microsecond", None):
|
||||
value = (
|
||||
f"""TO_TIMESTAMP"""
|
||||
f"""('{value.isoformat().replace("T", " ")}', """
|
||||
"""'YYYY-MM-DD HH24:MI:SS.FF')"""
|
||||
)
|
||||
else:
|
||||
value = (
|
||||
f"""TO_DATE"""
|
||||
f"""('{value.isoformat().replace("T", " ")}', """
|
||||
"""'YYYY-MM-DD HH24:MI:SS')"""
|
||||
)
|
||||
return value
|
||||
|
||||
return process
|
||||
|
||||
def _literal_processor_date(self, dialect):
|
||||
def process(value):
|
||||
if getattr(value, "microsecond", None):
|
||||
value = (
|
||||
f"""TO_TIMESTAMP"""
|
||||
f"""('{value.isoformat().split("T")[0]}', """
|
||||
"""'YYYY-MM-DD')"""
|
||||
)
|
||||
else:
|
||||
value = (
|
||||
f"""TO_DATE"""
|
||||
f"""('{value.isoformat().split("T")[0]}', """
|
||||
"""'YYYY-MM-DD')"""
|
||||
)
|
||||
return value
|
||||
|
||||
return process
|
||||
|
||||
|
||||
class DATE(_OracleDateLiteralRender, sqltypes.DateTime):
|
||||
"""Provide the oracle DATE type.
|
||||
|
||||
This type has no special Python behavior, except that it subclasses
|
||||
:class:`_types.DateTime`; this is to suit the fact that the Oracle
|
||||
``DATE`` type supports a time value.
|
||||
|
||||
"""
|
||||
|
||||
__visit_name__ = "DATE"
|
||||
|
||||
def literal_processor(self, dialect):
|
||||
return self._literal_processor_datetime(dialect)
|
||||
|
||||
def _compare_type_affinity(self, other):
|
||||
return other._type_affinity in (sqltypes.DateTime, sqltypes.Date)
|
||||
|
||||
|
||||
class _OracleDate(_OracleDateLiteralRender, sqltypes.Date):
|
||||
def literal_processor(self, dialect):
|
||||
return self._literal_processor_date(dialect)
|
||||
|
||||
|
||||
class INTERVAL(sqltypes.NativeForEmulated, sqltypes._AbstractInterval):
|
||||
__visit_name__ = "INTERVAL"
|
||||
|
||||
def __init__(self, day_precision=None, second_precision=None):
|
||||
"""Construct an INTERVAL.
|
||||
|
||||
Note that only DAY TO SECOND intervals are currently supported.
|
||||
This is due to a lack of support for YEAR TO MONTH intervals
|
||||
within available DBAPIs.
|
||||
|
||||
:param day_precision: the day precision value. this is the number of
|
||||
digits to store for the day field. Defaults to "2"
|
||||
:param second_precision: the second precision value. this is the
|
||||
number of digits to store for the fractional seconds field.
|
||||
Defaults to "6".
|
||||
|
||||
"""
|
||||
self.day_precision = day_precision
|
||||
self.second_precision = second_precision
|
||||
|
||||
@classmethod
|
||||
def _adapt_from_generic_interval(cls, interval):
|
||||
return INTERVAL(
|
||||
day_precision=interval.day_precision,
|
||||
second_precision=interval.second_precision,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def adapt_emulated_to_native(
|
||||
cls, interval: sqltypes.Interval, **kw # type: ignore[override]
|
||||
):
|
||||
return INTERVAL(
|
||||
day_precision=interval.day_precision,
|
||||
second_precision=interval.second_precision,
|
||||
)
|
||||
|
||||
@property
|
||||
def _type_affinity(self):
|
||||
return sqltypes.Interval
|
||||
|
||||
def as_generic(self, allow_nulltype=False):
|
||||
return sqltypes.Interval(
|
||||
native=True,
|
||||
second_precision=self.second_precision,
|
||||
day_precision=self.day_precision,
|
||||
)
|
||||
|
||||
@property
|
||||
def python_type(self) -> Type[dt.timedelta]:
|
||||
return dt.timedelta
|
||||
|
||||
def literal_processor(
|
||||
self, dialect: Dialect
|
||||
) -> Optional[_LiteralProcessorType[dt.timedelta]]:
|
||||
def process(value: dt.timedelta) -> str:
|
||||
return f"NUMTODSINTERVAL({value.total_seconds()}, 'SECOND')"
|
||||
|
||||
return process
|
||||
|
||||
|
||||
class TIMESTAMP(sqltypes.TIMESTAMP):
|
||||
"""Oracle implementation of ``TIMESTAMP``, which supports additional
|
||||
Oracle-specific modes
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, timezone: bool = False, local_timezone: bool = False):
|
||||
"""Construct a new :class:`_oracle.TIMESTAMP`.
|
||||
|
||||
:param timezone: boolean. Indicates that the TIMESTAMP type should
|
||||
use Oracle's ``TIMESTAMP WITH TIME ZONE`` datatype.
|
||||
|
||||
:param local_timezone: boolean. Indicates that the TIMESTAMP type
|
||||
should use Oracle's ``TIMESTAMP WITH LOCAL TIME ZONE`` datatype.
|
||||
|
||||
|
||||
"""
|
||||
if timezone and local_timezone:
|
||||
raise exc.ArgumentError(
|
||||
"timezone and local_timezone are mutually exclusive"
|
||||
)
|
||||
super().__init__(timezone=timezone)
|
||||
self.local_timezone = local_timezone
|
||||
|
||||
|
||||
class ROWID(sqltypes.TypeEngine):
|
||||
"""Oracle ROWID type.
|
||||
|
||||
When used in a cast() or similar, generates ROWID.
|
||||
|
||||
"""
|
||||
|
||||
__visit_name__ = "ROWID"
|
||||
|
||||
|
||||
class _OracleBoolean(sqltypes.Boolean):
|
||||
def get_dbapi_type(self, dbapi):
|
||||
return dbapi.NUMBER
|
|
@ -0,0 +1,167 @@
|
|||
# dialects/postgresql/__init__.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
from types import ModuleType
|
||||
|
||||
from . import array as arraylib # noqa # must be above base and other dialects
|
||||
from . import asyncpg # noqa
|
||||
from . import base
|
||||
from . import pg8000 # noqa
|
||||
from . import psycopg # noqa
|
||||
from . import psycopg2 # noqa
|
||||
from . import psycopg2cffi # noqa
|
||||
from .array import All
|
||||
from .array import Any
|
||||
from .array import ARRAY
|
||||
from .array import array
|
||||
from .base import BIGINT
|
||||
from .base import BOOLEAN
|
||||
from .base import CHAR
|
||||
from .base import DATE
|
||||
from .base import DOMAIN
|
||||
from .base import DOUBLE_PRECISION
|
||||
from .base import FLOAT
|
||||
from .base import INTEGER
|
||||
from .base import NUMERIC
|
||||
from .base import REAL
|
||||
from .base import SMALLINT
|
||||
from .base import TEXT
|
||||
from .base import UUID
|
||||
from .base import VARCHAR
|
||||
from .dml import Insert
|
||||
from .dml import insert
|
||||
from .ext import aggregate_order_by
|
||||
from .ext import array_agg
|
||||
from .ext import ExcludeConstraint
|
||||
from .ext import phraseto_tsquery
|
||||
from .ext import plainto_tsquery
|
||||
from .ext import to_tsquery
|
||||
from .ext import to_tsvector
|
||||
from .ext import ts_headline
|
||||
from .ext import websearch_to_tsquery
|
||||
from .hstore import HSTORE
|
||||
from .hstore import hstore
|
||||
from .json import JSON
|
||||
from .json import JSONB
|
||||
from .json import JSONPATH
|
||||
from .named_types import CreateDomainType
|
||||
from .named_types import CreateEnumType
|
||||
from .named_types import DropDomainType
|
||||
from .named_types import DropEnumType
|
||||
from .named_types import ENUM
|
||||
from .named_types import NamedType
|
||||
from .ranges import AbstractMultiRange
|
||||
from .ranges import AbstractRange
|
||||
from .ranges import AbstractSingleRange
|
||||
from .ranges import DATEMULTIRANGE
|
||||
from .ranges import DATERANGE
|
||||
from .ranges import INT4MULTIRANGE
|
||||
from .ranges import INT4RANGE
|
||||
from .ranges import INT8MULTIRANGE
|
||||
from .ranges import INT8RANGE
|
||||
from .ranges import MultiRange
|
||||
from .ranges import NUMMULTIRANGE
|
||||
from .ranges import NUMRANGE
|
||||
from .ranges import Range
|
||||
from .ranges import TSMULTIRANGE
|
||||
from .ranges import TSRANGE
|
||||
from .ranges import TSTZMULTIRANGE
|
||||
from .ranges import TSTZRANGE
|
||||
from .types import BIT
|
||||
from .types import BYTEA
|
||||
from .types import CIDR
|
||||
from .types import CITEXT
|
||||
from .types import INET
|
||||
from .types import INTERVAL
|
||||
from .types import MACADDR
|
||||
from .types import MACADDR8
|
||||
from .types import MONEY
|
||||
from .types import OID
|
||||
from .types import REGCLASS
|
||||
from .types import REGCONFIG
|
||||
from .types import TIME
|
||||
from .types import TIMESTAMP
|
||||
from .types import TSQUERY
|
||||
from .types import TSVECTOR
|
||||
|
||||
|
||||
# Alias psycopg also as psycopg_async
|
||||
psycopg_async = type(
|
||||
"psycopg_async", (ModuleType,), {"dialect": psycopg.dialect_async}
|
||||
)
|
||||
|
||||
base.dialect = dialect = psycopg2.dialect
|
||||
|
||||
|
||||
__all__ = (
|
||||
"INTEGER",
|
||||
"BIGINT",
|
||||
"SMALLINT",
|
||||
"VARCHAR",
|
||||
"CHAR",
|
||||
"TEXT",
|
||||
"NUMERIC",
|
||||
"FLOAT",
|
||||
"REAL",
|
||||
"INET",
|
||||
"CIDR",
|
||||
"CITEXT",
|
||||
"UUID",
|
||||
"BIT",
|
||||
"MACADDR",
|
||||
"MACADDR8",
|
||||
"MONEY",
|
||||
"OID",
|
||||
"REGCLASS",
|
||||
"REGCONFIG",
|
||||
"TSQUERY",
|
||||
"TSVECTOR",
|
||||
"DOUBLE_PRECISION",
|
||||
"TIMESTAMP",
|
||||
"TIME",
|
||||
"DATE",
|
||||
"BYTEA",
|
||||
"BOOLEAN",
|
||||
"INTERVAL",
|
||||
"ARRAY",
|
||||
"ENUM",
|
||||
"DOMAIN",
|
||||
"dialect",
|
||||
"array",
|
||||
"HSTORE",
|
||||
"hstore",
|
||||
"INT4RANGE",
|
||||
"INT8RANGE",
|
||||
"NUMRANGE",
|
||||
"DATERANGE",
|
||||
"INT4MULTIRANGE",
|
||||
"INT8MULTIRANGE",
|
||||
"NUMMULTIRANGE",
|
||||
"DATEMULTIRANGE",
|
||||
"TSVECTOR",
|
||||
"TSRANGE",
|
||||
"TSTZRANGE",
|
||||
"TSMULTIRANGE",
|
||||
"TSTZMULTIRANGE",
|
||||
"JSON",
|
||||
"JSONB",
|
||||
"JSONPATH",
|
||||
"Any",
|
||||
"All",
|
||||
"DropEnumType",
|
||||
"DropDomainType",
|
||||
"CreateDomainType",
|
||||
"NamedType",
|
||||
"CreateEnumType",
|
||||
"ExcludeConstraint",
|
||||
"Range",
|
||||
"aggregate_order_by",
|
||||
"array_agg",
|
||||
"insert",
|
||||
"Insert",
|
||||
)
|
|
@ -0,0 +1,187 @@
|
|||
# dialects/postgresql/_psycopg_common.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
from __future__ import annotations
|
||||
|
||||
import decimal
|
||||
|
||||
from .array import ARRAY as PGARRAY
|
||||
from .base import _DECIMAL_TYPES
|
||||
from .base import _FLOAT_TYPES
|
||||
from .base import _INT_TYPES
|
||||
from .base import PGDialect
|
||||
from .base import PGExecutionContext
|
||||
from .hstore import HSTORE
|
||||
from .pg_catalog import _SpaceVector
|
||||
from .pg_catalog import INT2VECTOR
|
||||
from .pg_catalog import OIDVECTOR
|
||||
from ... import exc
|
||||
from ... import types as sqltypes
|
||||
from ... import util
|
||||
from ...engine import processors
|
||||
|
||||
_server_side_id = util.counter()
|
||||
|
||||
|
||||
class _PsycopgNumeric(sqltypes.Numeric):
|
||||
def bind_processor(self, dialect):
|
||||
return None
|
||||
|
||||
def result_processor(self, dialect, coltype):
|
||||
if self.asdecimal:
|
||||
if coltype in _FLOAT_TYPES:
|
||||
return processors.to_decimal_processor_factory(
|
||||
decimal.Decimal, self._effective_decimal_return_scale
|
||||
)
|
||||
elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
|
||||
# psycopg returns Decimal natively for 1700
|
||||
return None
|
||||
else:
|
||||
raise exc.InvalidRequestError(
|
||||
"Unknown PG numeric type: %d" % coltype
|
||||
)
|
||||
else:
|
||||
if coltype in _FLOAT_TYPES:
|
||||
# psycopg returns float natively for 701
|
||||
return None
|
||||
elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
|
||||
return processors.to_float
|
||||
else:
|
||||
raise exc.InvalidRequestError(
|
||||
"Unknown PG numeric type: %d" % coltype
|
||||
)
|
||||
|
||||
|
||||
class _PsycopgFloat(_PsycopgNumeric):
|
||||
__visit_name__ = "float"
|
||||
|
||||
|
||||
class _PsycopgHStore(HSTORE):
|
||||
def bind_processor(self, dialect):
|
||||
if dialect._has_native_hstore:
|
||||
return None
|
||||
else:
|
||||
return super().bind_processor(dialect)
|
||||
|
||||
def result_processor(self, dialect, coltype):
|
||||
if dialect._has_native_hstore:
|
||||
return None
|
||||
else:
|
||||
return super().result_processor(dialect, coltype)
|
||||
|
||||
|
||||
class _PsycopgARRAY(PGARRAY):
|
||||
render_bind_cast = True
|
||||
|
||||
|
||||
class _PsycopgINT2VECTOR(_SpaceVector, INT2VECTOR):
|
||||
pass
|
||||
|
||||
|
||||
class _PsycopgOIDVECTOR(_SpaceVector, OIDVECTOR):
|
||||
pass
|
||||
|
||||
|
||||
class _PGExecutionContext_common_psycopg(PGExecutionContext):
|
||||
def create_server_side_cursor(self):
|
||||
# use server-side cursors:
|
||||
# psycopg
|
||||
# https://www.psycopg.org/psycopg3/docs/advanced/cursors.html#server-side-cursors
|
||||
# psycopg2
|
||||
# https://www.psycopg.org/docs/usage.html#server-side-cursors
|
||||
ident = "c_%s_%s" % (hex(id(self))[2:], hex(_server_side_id())[2:])
|
||||
return self._dbapi_connection.cursor(ident)
|
||||
|
||||
|
||||
class _PGDialect_common_psycopg(PGDialect):
|
||||
supports_statement_cache = True
|
||||
supports_server_side_cursors = True
|
||||
|
||||
default_paramstyle = "pyformat"
|
||||
|
||||
_has_native_hstore = True
|
||||
|
||||
colspecs = util.update_copy(
|
||||
PGDialect.colspecs,
|
||||
{
|
||||
sqltypes.Numeric: _PsycopgNumeric,
|
||||
sqltypes.Float: _PsycopgFloat,
|
||||
HSTORE: _PsycopgHStore,
|
||||
sqltypes.ARRAY: _PsycopgARRAY,
|
||||
INT2VECTOR: _PsycopgINT2VECTOR,
|
||||
OIDVECTOR: _PsycopgOIDVECTOR,
|
||||
},
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client_encoding=None,
|
||||
use_native_hstore=True,
|
||||
**kwargs,
|
||||
):
|
||||
PGDialect.__init__(self, **kwargs)
|
||||
if not use_native_hstore:
|
||||
self._has_native_hstore = False
|
||||
self.use_native_hstore = use_native_hstore
|
||||
self.client_encoding = client_encoding
|
||||
|
||||
def create_connect_args(self, url):
|
||||
opts = url.translate_connect_args(username="user", database="dbname")
|
||||
|
||||
multihosts, multiports = self._split_multihost_from_url(url)
|
||||
|
||||
if opts or url.query:
|
||||
if not opts:
|
||||
opts = {}
|
||||
if "port" in opts:
|
||||
opts["port"] = int(opts["port"])
|
||||
opts.update(url.query)
|
||||
|
||||
if multihosts:
|
||||
opts["host"] = ",".join(multihosts)
|
||||
comma_ports = ",".join(str(p) if p else "" for p in multiports)
|
||||
if comma_ports:
|
||||
opts["port"] = comma_ports
|
||||
return ([], opts)
|
||||
else:
|
||||
# no connection arguments whatsoever; psycopg2.connect()
|
||||
# requires that "dsn" be present as a blank string.
|
||||
return ([""], opts)
|
||||
|
||||
def get_isolation_level_values(self, dbapi_connection):
|
||||
return (
|
||||
"AUTOCOMMIT",
|
||||
"READ COMMITTED",
|
||||
"READ UNCOMMITTED",
|
||||
"REPEATABLE READ",
|
||||
"SERIALIZABLE",
|
||||
)
|
||||
|
||||
def set_deferrable(self, connection, value):
|
||||
connection.deferrable = value
|
||||
|
||||
def get_deferrable(self, connection):
|
||||
return connection.deferrable
|
||||
|
||||
def _do_autocommit(self, connection, value):
|
||||
connection.autocommit = value
|
||||
|
||||
def do_ping(self, dbapi_connection):
|
||||
cursor = None
|
||||
before_autocommit = dbapi_connection.autocommit
|
||||
|
||||
if not before_autocommit:
|
||||
dbapi_connection.autocommit = True
|
||||
cursor = dbapi_connection.cursor()
|
||||
try:
|
||||
cursor.execute(self._dialect_specific_select_one)
|
||||
finally:
|
||||
cursor.close()
|
||||
if not before_autocommit and not dbapi_connection.closed:
|
||||
dbapi_connection.autocommit = before_autocommit
|
||||
|
||||
return True
|
|
@ -0,0 +1,424 @@
|
|||
# dialects/postgresql/array.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
from typing import TypeVar
|
||||
|
||||
from .operators import CONTAINED_BY
|
||||
from .operators import CONTAINS
|
||||
from .operators import OVERLAP
|
||||
from ... import types as sqltypes
|
||||
from ... import util
|
||||
from ...sql import expression
|
||||
from ...sql import operators
|
||||
from ...sql._typing import _TypeEngineArgument
|
||||
|
||||
|
||||
_T = TypeVar("_T", bound=Any)
|
||||
|
||||
|
||||
def Any(other, arrexpr, operator=operators.eq):
|
||||
"""A synonym for the ARRAY-level :meth:`.ARRAY.Comparator.any` method.
|
||||
See that method for details.
|
||||
|
||||
"""
|
||||
|
||||
return arrexpr.any(other, operator)
|
||||
|
||||
|
||||
def All(other, arrexpr, operator=operators.eq):
|
||||
"""A synonym for the ARRAY-level :meth:`.ARRAY.Comparator.all` method.
|
||||
See that method for details.
|
||||
|
||||
"""
|
||||
|
||||
return arrexpr.all(other, operator)
|
||||
|
||||
|
||||
class array(expression.ExpressionClauseList[_T]):
|
||||
"""A PostgreSQL ARRAY literal.
|
||||
|
||||
This is used to produce ARRAY literals in SQL expressions, e.g.::
|
||||
|
||||
from sqlalchemy.dialects.postgresql import array
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy import select, func
|
||||
|
||||
stmt = select(array([1,2]) + array([3,4,5]))
|
||||
|
||||
print(stmt.compile(dialect=postgresql.dialect()))
|
||||
|
||||
Produces the SQL::
|
||||
|
||||
SELECT ARRAY[%(param_1)s, %(param_2)s] ||
|
||||
ARRAY[%(param_3)s, %(param_4)s, %(param_5)s]) AS anon_1
|
||||
|
||||
An instance of :class:`.array` will always have the datatype
|
||||
:class:`_types.ARRAY`. The "inner" type of the array is inferred from
|
||||
the values present, unless the ``type_`` keyword argument is passed::
|
||||
|
||||
array(['foo', 'bar'], type_=CHAR)
|
||||
|
||||
Multidimensional arrays are produced by nesting :class:`.array` constructs.
|
||||
The dimensionality of the final :class:`_types.ARRAY`
|
||||
type is calculated by
|
||||
recursively adding the dimensions of the inner :class:`_types.ARRAY`
|
||||
type::
|
||||
|
||||
stmt = select(
|
||||
array([
|
||||
array([1, 2]), array([3, 4]), array([column('q'), column('x')])
|
||||
])
|
||||
)
|
||||
print(stmt.compile(dialect=postgresql.dialect()))
|
||||
|
||||
Produces::
|
||||
|
||||
SELECT ARRAY[ARRAY[%(param_1)s, %(param_2)s],
|
||||
ARRAY[%(param_3)s, %(param_4)s], ARRAY[q, x]] AS anon_1
|
||||
|
||||
.. versionadded:: 1.3.6 added support for multidimensional array literals
|
||||
|
||||
.. seealso::
|
||||
|
||||
:class:`_postgresql.ARRAY`
|
||||
|
||||
"""
|
||||
|
||||
__visit_name__ = "array"
|
||||
|
||||
stringify_dialect = "postgresql"
|
||||
inherit_cache = True
|
||||
|
||||
def __init__(self, clauses, **kw):
|
||||
type_arg = kw.pop("type_", None)
|
||||
super().__init__(operators.comma_op, *clauses, **kw)
|
||||
|
||||
self._type_tuple = [arg.type for arg in self.clauses]
|
||||
|
||||
main_type = (
|
||||
type_arg
|
||||
if type_arg is not None
|
||||
else self._type_tuple[0] if self._type_tuple else sqltypes.NULLTYPE
|
||||
)
|
||||
|
||||
if isinstance(main_type, ARRAY):
|
||||
self.type = ARRAY(
|
||||
main_type.item_type,
|
||||
dimensions=(
|
||||
main_type.dimensions + 1
|
||||
if main_type.dimensions is not None
|
||||
else 2
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.type = ARRAY(main_type)
|
||||
|
||||
@property
|
||||
def _select_iterable(self):
|
||||
return (self,)
|
||||
|
||||
def _bind_param(self, operator, obj, _assume_scalar=False, type_=None):
|
||||
if _assume_scalar or operator is operators.getitem:
|
||||
return expression.BindParameter(
|
||||
None,
|
||||
obj,
|
||||
_compared_to_operator=operator,
|
||||
type_=type_,
|
||||
_compared_to_type=self.type,
|
||||
unique=True,
|
||||
)
|
||||
|
||||
else:
|
||||
return array(
|
||||
[
|
||||
self._bind_param(
|
||||
operator, o, _assume_scalar=True, type_=type_
|
||||
)
|
||||
for o in obj
|
||||
]
|
||||
)
|
||||
|
||||
def self_group(self, against=None):
|
||||
if against in (operators.any_op, operators.all_op, operators.getitem):
|
||||
return expression.Grouping(self)
|
||||
else:
|
||||
return self
|
||||
|
||||
|
||||
class ARRAY(sqltypes.ARRAY):
|
||||
"""PostgreSQL ARRAY type.
|
||||
|
||||
The :class:`_postgresql.ARRAY` type is constructed in the same way
|
||||
as the core :class:`_types.ARRAY` type; a member type is required, and a
|
||||
number of dimensions is recommended if the type is to be used for more
|
||||
than one dimension::
|
||||
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
mytable = Table("mytable", metadata,
|
||||
Column("data", postgresql.ARRAY(Integer, dimensions=2))
|
||||
)
|
||||
|
||||
The :class:`_postgresql.ARRAY` type provides all operations defined on the
|
||||
core :class:`_types.ARRAY` type, including support for "dimensions",
|
||||
indexed access, and simple matching such as
|
||||
:meth:`.types.ARRAY.Comparator.any` and
|
||||
:meth:`.types.ARRAY.Comparator.all`. :class:`_postgresql.ARRAY`
|
||||
class also
|
||||
provides PostgreSQL-specific methods for containment operations, including
|
||||
:meth:`.postgresql.ARRAY.Comparator.contains`
|
||||
:meth:`.postgresql.ARRAY.Comparator.contained_by`, and
|
||||
:meth:`.postgresql.ARRAY.Comparator.overlap`, e.g.::
|
||||
|
||||
mytable.c.data.contains([1, 2])
|
||||
|
||||
The :class:`_postgresql.ARRAY` type may not be supported on all
|
||||
PostgreSQL DBAPIs; it is currently known to work on psycopg2 only.
|
||||
|
||||
Additionally, the :class:`_postgresql.ARRAY`
|
||||
type does not work directly in
|
||||
conjunction with the :class:`.ENUM` type. For a workaround, see the
|
||||
special type at :ref:`postgresql_array_of_enum`.
|
||||
|
||||
.. container:: topic
|
||||
|
||||
**Detecting Changes in ARRAY columns when using the ORM**
|
||||
|
||||
The :class:`_postgresql.ARRAY` type, when used with the SQLAlchemy ORM,
|
||||
does not detect in-place mutations to the array. In order to detect
|
||||
these, the :mod:`sqlalchemy.ext.mutable` extension must be used, using
|
||||
the :class:`.MutableList` class::
|
||||
|
||||
from sqlalchemy.dialects.postgresql import ARRAY
|
||||
from sqlalchemy.ext.mutable import MutableList
|
||||
|
||||
class SomeOrmClass(Base):
|
||||
# ...
|
||||
|
||||
data = Column(MutableList.as_mutable(ARRAY(Integer)))
|
||||
|
||||
This extension will allow "in-place" changes such to the array
|
||||
such as ``.append()`` to produce events which will be detected by the
|
||||
unit of work. Note that changes to elements **inside** the array,
|
||||
including subarrays that are mutated in place, are **not** detected.
|
||||
|
||||
Alternatively, assigning a new array value to an ORM element that
|
||||
replaces the old one will always trigger a change event.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:class:`_types.ARRAY` - base array type
|
||||
|
||||
:class:`_postgresql.array` - produces a literal array value.
|
||||
|
||||
"""
|
||||
|
||||
class Comparator(sqltypes.ARRAY.Comparator):
|
||||
"""Define comparison operations for :class:`_types.ARRAY`.
|
||||
|
||||
Note that these operations are in addition to those provided
|
||||
by the base :class:`.types.ARRAY.Comparator` class, including
|
||||
:meth:`.types.ARRAY.Comparator.any` and
|
||||
:meth:`.types.ARRAY.Comparator.all`.
|
||||
|
||||
"""
|
||||
|
||||
def contains(self, other, **kwargs):
|
||||
"""Boolean expression. Test if elements are a superset of the
|
||||
elements of the argument array expression.
|
||||
|
||||
kwargs may be ignored by this operator but are required for API
|
||||
conformance.
|
||||
"""
|
||||
return self.operate(CONTAINS, other, result_type=sqltypes.Boolean)
|
||||
|
||||
def contained_by(self, other):
|
||||
"""Boolean expression. Test if elements are a proper subset of the
|
||||
elements of the argument array expression.
|
||||
"""
|
||||
return self.operate(
|
||||
CONTAINED_BY, other, result_type=sqltypes.Boolean
|
||||
)
|
||||
|
||||
def overlap(self, other):
|
||||
"""Boolean expression. Test if array has elements in common with
|
||||
an argument array expression.
|
||||
"""
|
||||
return self.operate(OVERLAP, other, result_type=sqltypes.Boolean)
|
||||
|
||||
comparator_factory = Comparator
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
item_type: _TypeEngineArgument[Any],
|
||||
as_tuple: bool = False,
|
||||
dimensions: Optional[int] = None,
|
||||
zero_indexes: bool = False,
|
||||
):
|
||||
"""Construct an ARRAY.
|
||||
|
||||
E.g.::
|
||||
|
||||
Column('myarray', ARRAY(Integer))
|
||||
|
||||
Arguments are:
|
||||
|
||||
:param item_type: The data type of items of this array. Note that
|
||||
dimensionality is irrelevant here, so multi-dimensional arrays like
|
||||
``INTEGER[][]``, are constructed as ``ARRAY(Integer)``, not as
|
||||
``ARRAY(ARRAY(Integer))`` or such.
|
||||
|
||||
:param as_tuple=False: Specify whether return results
|
||||
should be converted to tuples from lists. DBAPIs such
|
||||
as psycopg2 return lists by default. When tuples are
|
||||
returned, the results are hashable.
|
||||
|
||||
:param dimensions: if non-None, the ARRAY will assume a fixed
|
||||
number of dimensions. This will cause the DDL emitted for this
|
||||
ARRAY to include the exact number of bracket clauses ``[]``,
|
||||
and will also optimize the performance of the type overall.
|
||||
Note that PG arrays are always implicitly "non-dimensioned",
|
||||
meaning they can store any number of dimensions no matter how
|
||||
they were declared.
|
||||
|
||||
:param zero_indexes=False: when True, index values will be converted
|
||||
between Python zero-based and PostgreSQL one-based indexes, e.g.
|
||||
a value of one will be added to all index values before passing
|
||||
to the database.
|
||||
|
||||
"""
|
||||
if isinstance(item_type, ARRAY):
|
||||
raise ValueError(
|
||||
"Do not nest ARRAY types; ARRAY(basetype) "
|
||||
"handles multi-dimensional arrays of basetype"
|
||||
)
|
||||
if isinstance(item_type, type):
|
||||
item_type = item_type()
|
||||
self.item_type = item_type
|
||||
self.as_tuple = as_tuple
|
||||
self.dimensions = dimensions
|
||||
self.zero_indexes = zero_indexes
|
||||
|
||||
@property
|
||||
def hashable(self):
|
||||
return self.as_tuple
|
||||
|
||||
@property
|
||||
def python_type(self):
|
||||
return list
|
||||
|
||||
def compare_values(self, x, y):
|
||||
return x == y
|
||||
|
||||
@util.memoized_property
|
||||
def _against_native_enum(self):
|
||||
return (
|
||||
isinstance(self.item_type, sqltypes.Enum)
|
||||
and self.item_type.native_enum
|
||||
)
|
||||
|
||||
def literal_processor(self, dialect):
|
||||
item_proc = self.item_type.dialect_impl(dialect).literal_processor(
|
||||
dialect
|
||||
)
|
||||
if item_proc is None:
|
||||
return None
|
||||
|
||||
def to_str(elements):
|
||||
return f"ARRAY[{', '.join(elements)}]"
|
||||
|
||||
def process(value):
|
||||
inner = self._apply_item_processor(
|
||||
value, item_proc, self.dimensions, to_str
|
||||
)
|
||||
return inner
|
||||
|
||||
return process
|
||||
|
||||
def bind_processor(self, dialect):
|
||||
item_proc = self.item_type.dialect_impl(dialect).bind_processor(
|
||||
dialect
|
||||
)
|
||||
|
||||
def process(value):
|
||||
if value is None:
|
||||
return value
|
||||
else:
|
||||
return self._apply_item_processor(
|
||||
value, item_proc, self.dimensions, list
|
||||
)
|
||||
|
||||
return process
|
||||
|
||||
def result_processor(self, dialect, coltype):
|
||||
item_proc = self.item_type.dialect_impl(dialect).result_processor(
|
||||
dialect, coltype
|
||||
)
|
||||
|
||||
def process(value):
|
||||
if value is None:
|
||||
return value
|
||||
else:
|
||||
return self._apply_item_processor(
|
||||
value,
|
||||
item_proc,
|
||||
self.dimensions,
|
||||
tuple if self.as_tuple else list,
|
||||
)
|
||||
|
||||
if self._against_native_enum:
|
||||
super_rp = process
|
||||
pattern = re.compile(r"^{(.*)}$")
|
||||
|
||||
def handle_raw_string(value):
|
||||
inner = pattern.match(value).group(1)
|
||||
return _split_enum_values(inner)
|
||||
|
||||
def process(value):
|
||||
if value is None:
|
||||
return value
|
||||
# isinstance(value, str) is required to handle
|
||||
# the case where a TypeDecorator for and Array of Enum is
|
||||
# used like was required in sa < 1.3.17
|
||||
return super_rp(
|
||||
handle_raw_string(value)
|
||||
if isinstance(value, str)
|
||||
else value
|
||||
)
|
||||
|
||||
return process
|
||||
|
||||
|
||||
def _split_enum_values(array_string):
|
||||
if '"' not in array_string:
|
||||
# no escape char is present so it can just split on the comma
|
||||
return array_string.split(",") if array_string else []
|
||||
|
||||
# handles quoted strings from:
|
||||
# r'abc,"quoted","also\\\\quoted", "quoted, comma", "esc \" quot", qpr'
|
||||
# returns
|
||||
# ['abc', 'quoted', 'also\\quoted', 'quoted, comma', 'esc " quot', 'qpr']
|
||||
text = array_string.replace(r"\"", "_$ESC_QUOTE$_")
|
||||
text = text.replace(r"\\", "\\")
|
||||
result = []
|
||||
on_quotes = re.split(r'(")', text)
|
||||
in_quotes = False
|
||||
for tok in on_quotes:
|
||||
if tok == '"':
|
||||
in_quotes = not in_quotes
|
||||
elif in_quotes:
|
||||
result.append(tok.replace("_$ESC_QUOTE$_", '"'))
|
||||
else:
|
||||
result.extend(re.findall(r"([^\s,]+),?", tok))
|
||||
return result
|
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
|
@ -0,0 +1,310 @@
|
|||
# dialects/postgresql/dml.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
|
||||
from . import ext
|
||||
from .._typing import _OnConflictConstraintT
|
||||
from .._typing import _OnConflictIndexElementsT
|
||||
from .._typing import _OnConflictIndexWhereT
|
||||
from .._typing import _OnConflictSetT
|
||||
from .._typing import _OnConflictWhereT
|
||||
from ... import util
|
||||
from ...sql import coercions
|
||||
from ...sql import roles
|
||||
from ...sql import schema
|
||||
from ...sql._typing import _DMLTableArgument
|
||||
from ...sql.base import _exclusive_against
|
||||
from ...sql.base import _generative
|
||||
from ...sql.base import ColumnCollection
|
||||
from ...sql.base import ReadOnlyColumnCollection
|
||||
from ...sql.dml import Insert as StandardInsert
|
||||
from ...sql.elements import ClauseElement
|
||||
from ...sql.elements import KeyedColumnElement
|
||||
from ...sql.expression import alias
|
||||
from ...util.typing import Self
|
||||
|
||||
|
||||
__all__ = ("Insert", "insert")
|
||||
|
||||
|
||||
def insert(table: _DMLTableArgument) -> Insert:
|
||||
"""Construct a PostgreSQL-specific variant :class:`_postgresql.Insert`
|
||||
construct.
|
||||
|
||||
.. container:: inherited_member
|
||||
|
||||
The :func:`sqlalchemy.dialects.postgresql.insert` function creates
|
||||
a :class:`sqlalchemy.dialects.postgresql.Insert`. This class is based
|
||||
on the dialect-agnostic :class:`_sql.Insert` construct which may
|
||||
be constructed using the :func:`_sql.insert` function in
|
||||
SQLAlchemy Core.
|
||||
|
||||
The :class:`_postgresql.Insert` construct includes additional methods
|
||||
:meth:`_postgresql.Insert.on_conflict_do_update`,
|
||||
:meth:`_postgresql.Insert.on_conflict_do_nothing`.
|
||||
|
||||
"""
|
||||
return Insert(table)
|
||||
|
||||
|
||||
class Insert(StandardInsert):
|
||||
"""PostgreSQL-specific implementation of INSERT.
|
||||
|
||||
Adds methods for PG-specific syntaxes such as ON CONFLICT.
|
||||
|
||||
The :class:`_postgresql.Insert` object is created using the
|
||||
:func:`sqlalchemy.dialects.postgresql.insert` function.
|
||||
|
||||
"""
|
||||
|
||||
stringify_dialect = "postgresql"
|
||||
inherit_cache = False
|
||||
|
||||
@util.memoized_property
|
||||
def excluded(
|
||||
self,
|
||||
) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]:
|
||||
"""Provide the ``excluded`` namespace for an ON CONFLICT statement
|
||||
|
||||
PG's ON CONFLICT clause allows reference to the row that would
|
||||
be inserted, known as ``excluded``. This attribute provides
|
||||
all columns in this row to be referenceable.
|
||||
|
||||
.. tip:: The :attr:`_postgresql.Insert.excluded` attribute is an
|
||||
instance of :class:`_expression.ColumnCollection`, which provides
|
||||
an interface the same as that of the :attr:`_schema.Table.c`
|
||||
collection described at :ref:`metadata_tables_and_columns`.
|
||||
With this collection, ordinary names are accessible like attributes
|
||||
(e.g. ``stmt.excluded.some_column``), but special names and
|
||||
dictionary method names should be accessed using indexed access,
|
||||
such as ``stmt.excluded["column name"]`` or
|
||||
``stmt.excluded["values"]``. See the docstring for
|
||||
:class:`_expression.ColumnCollection` for further examples.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:ref:`postgresql_insert_on_conflict` - example of how
|
||||
to use :attr:`_expression.Insert.excluded`
|
||||
|
||||
"""
|
||||
return alias(self.table, name="excluded").columns
|
||||
|
||||
_on_conflict_exclusive = _exclusive_against(
|
||||
"_post_values_clause",
|
||||
msgs={
|
||||
"_post_values_clause": "This Insert construct already has "
|
||||
"an ON CONFLICT clause established"
|
||||
},
|
||||
)
|
||||
|
||||
@_generative
|
||||
@_on_conflict_exclusive
|
||||
def on_conflict_do_update(
|
||||
self,
|
||||
constraint: _OnConflictConstraintT = None,
|
||||
index_elements: _OnConflictIndexElementsT = None,
|
||||
index_where: _OnConflictIndexWhereT = None,
|
||||
set_: _OnConflictSetT = None,
|
||||
where: _OnConflictWhereT = None,
|
||||
) -> Self:
|
||||
r"""
|
||||
Specifies a DO UPDATE SET action for ON CONFLICT clause.
|
||||
|
||||
Either the ``constraint`` or ``index_elements`` argument is
|
||||
required, but only one of these can be specified.
|
||||
|
||||
:param constraint:
|
||||
The name of a unique or exclusion constraint on the table,
|
||||
or the constraint object itself if it has a .name attribute.
|
||||
|
||||
:param index_elements:
|
||||
A sequence consisting of string column names, :class:`_schema.Column`
|
||||
objects, or other column expression objects that will be used
|
||||
to infer a target index.
|
||||
|
||||
:param index_where:
|
||||
Additional WHERE criterion that can be used to infer a
|
||||
conditional target index.
|
||||
|
||||
:param set\_:
|
||||
A dictionary or other mapping object
|
||||
where the keys are either names of columns in the target table,
|
||||
or :class:`_schema.Column` objects or other ORM-mapped columns
|
||||
matching that of the target table, and expressions or literals
|
||||
as values, specifying the ``SET`` actions to take.
|
||||
|
||||
.. versionadded:: 1.4 The
|
||||
:paramref:`_postgresql.Insert.on_conflict_do_update.set_`
|
||||
parameter supports :class:`_schema.Column` objects from the target
|
||||
:class:`_schema.Table` as keys.
|
||||
|
||||
.. warning:: This dictionary does **not** take into account
|
||||
Python-specified default UPDATE values or generation functions,
|
||||
e.g. those specified using :paramref:`_schema.Column.onupdate`.
|
||||
These values will not be exercised for an ON CONFLICT style of
|
||||
UPDATE, unless they are manually specified in the
|
||||
:paramref:`.Insert.on_conflict_do_update.set_` dictionary.
|
||||
|
||||
:param where:
|
||||
Optional argument. If present, can be a literal SQL
|
||||
string or an acceptable expression for a ``WHERE`` clause
|
||||
that restricts the rows affected by ``DO UPDATE SET``. Rows
|
||||
not meeting the ``WHERE`` condition will not be updated
|
||||
(effectively a ``DO NOTHING`` for those rows).
|
||||
|
||||
|
||||
.. seealso::
|
||||
|
||||
:ref:`postgresql_insert_on_conflict`
|
||||
|
||||
"""
|
||||
self._post_values_clause = OnConflictDoUpdate(
|
||||
constraint, index_elements, index_where, set_, where
|
||||
)
|
||||
return self
|
||||
|
||||
@_generative
|
||||
@_on_conflict_exclusive
|
||||
def on_conflict_do_nothing(
|
||||
self,
|
||||
constraint: _OnConflictConstraintT = None,
|
||||
index_elements: _OnConflictIndexElementsT = None,
|
||||
index_where: _OnConflictIndexWhereT = None,
|
||||
) -> Self:
|
||||
"""
|
||||
Specifies a DO NOTHING action for ON CONFLICT clause.
|
||||
|
||||
The ``constraint`` and ``index_elements`` arguments
|
||||
are optional, but only one of these can be specified.
|
||||
|
||||
:param constraint:
|
||||
The name of a unique or exclusion constraint on the table,
|
||||
or the constraint object itself if it has a .name attribute.
|
||||
|
||||
:param index_elements:
|
||||
A sequence consisting of string column names, :class:`_schema.Column`
|
||||
objects, or other column expression objects that will be used
|
||||
to infer a target index.
|
||||
|
||||
:param index_where:
|
||||
Additional WHERE criterion that can be used to infer a
|
||||
conditional target index.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:ref:`postgresql_insert_on_conflict`
|
||||
|
||||
"""
|
||||
self._post_values_clause = OnConflictDoNothing(
|
||||
constraint, index_elements, index_where
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class OnConflictClause(ClauseElement):
|
||||
stringify_dialect = "postgresql"
|
||||
|
||||
constraint_target: Optional[str]
|
||||
inferred_target_elements: _OnConflictIndexElementsT
|
||||
inferred_target_whereclause: _OnConflictIndexWhereT
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
constraint: _OnConflictConstraintT = None,
|
||||
index_elements: _OnConflictIndexElementsT = None,
|
||||
index_where: _OnConflictIndexWhereT = None,
|
||||
):
|
||||
if constraint is not None:
|
||||
if not isinstance(constraint, str) and isinstance(
|
||||
constraint,
|
||||
(schema.Constraint, ext.ExcludeConstraint),
|
||||
):
|
||||
constraint = getattr(constraint, "name") or constraint
|
||||
|
||||
if constraint is not None:
|
||||
if index_elements is not None:
|
||||
raise ValueError(
|
||||
"'constraint' and 'index_elements' are mutually exclusive"
|
||||
)
|
||||
|
||||
if isinstance(constraint, str):
|
||||
self.constraint_target = constraint
|
||||
self.inferred_target_elements = None
|
||||
self.inferred_target_whereclause = None
|
||||
elif isinstance(constraint, schema.Index):
|
||||
index_elements = constraint.expressions
|
||||
index_where = constraint.dialect_options["postgresql"].get(
|
||||
"where"
|
||||
)
|
||||
elif isinstance(constraint, ext.ExcludeConstraint):
|
||||
index_elements = constraint.columns
|
||||
index_where = constraint.where
|
||||
else:
|
||||
index_elements = constraint.columns
|
||||
index_where = constraint.dialect_options["postgresql"].get(
|
||||
"where"
|
||||
)
|
||||
|
||||
if index_elements is not None:
|
||||
self.constraint_target = None
|
||||
self.inferred_target_elements = index_elements
|
||||
self.inferred_target_whereclause = index_where
|
||||
elif constraint is None:
|
||||
self.constraint_target = self.inferred_target_elements = (
|
||||
self.inferred_target_whereclause
|
||||
) = None
|
||||
|
||||
|
||||
class OnConflictDoNothing(OnConflictClause):
|
||||
__visit_name__ = "on_conflict_do_nothing"
|
||||
|
||||
|
||||
class OnConflictDoUpdate(OnConflictClause):
|
||||
__visit_name__ = "on_conflict_do_update"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
constraint: _OnConflictConstraintT = None,
|
||||
index_elements: _OnConflictIndexElementsT = None,
|
||||
index_where: _OnConflictIndexWhereT = None,
|
||||
set_: _OnConflictSetT = None,
|
||||
where: _OnConflictWhereT = None,
|
||||
):
|
||||
super().__init__(
|
||||
constraint=constraint,
|
||||
index_elements=index_elements,
|
||||
index_where=index_where,
|
||||
)
|
||||
|
||||
if (
|
||||
self.inferred_target_elements is None
|
||||
and self.constraint_target is None
|
||||
):
|
||||
raise ValueError(
|
||||
"Either constraint or index_elements, "
|
||||
"but not both, must be specified unless DO NOTHING"
|
||||
)
|
||||
|
||||
if isinstance(set_, dict):
|
||||
if not set_:
|
||||
raise ValueError("set parameter dictionary must not be empty")
|
||||
elif isinstance(set_, ColumnCollection):
|
||||
set_ = dict(set_)
|
||||
else:
|
||||
raise ValueError(
|
||||
"set parameter must be a non-empty dictionary "
|
||||
"or a ColumnCollection such as the `.c.` collection "
|
||||
"of a Table object"
|
||||
)
|
||||
self.update_values_to_set = [
|
||||
(coercions.expect(roles.DMLColumnRole, key), value)
|
||||
for key, value in set_.items()
|
||||
]
|
||||
self.update_whereclause = where
|
|
@ -0,0 +1,496 @@
|
|||
# dialects/postgresql/ext.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TypeVar
|
||||
|
||||
from . import types
|
||||
from .array import ARRAY
|
||||
from ...sql import coercions
|
||||
from ...sql import elements
|
||||
from ...sql import expression
|
||||
from ...sql import functions
|
||||
from ...sql import roles
|
||||
from ...sql import schema
|
||||
from ...sql.schema import ColumnCollectionConstraint
|
||||
from ...sql.sqltypes import TEXT
|
||||
from ...sql.visitors import InternalTraversal
|
||||
|
||||
_T = TypeVar("_T", bound=Any)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...sql.visitors import _TraverseInternalsType
|
||||
|
||||
|
||||
class aggregate_order_by(expression.ColumnElement):
|
||||
"""Represent a PostgreSQL aggregate order by expression.
|
||||
|
||||
E.g.::
|
||||
|
||||
from sqlalchemy.dialects.postgresql import aggregate_order_by
|
||||
expr = func.array_agg(aggregate_order_by(table.c.a, table.c.b.desc()))
|
||||
stmt = select(expr)
|
||||
|
||||
would represent the expression::
|
||||
|
||||
SELECT array_agg(a ORDER BY b DESC) FROM table;
|
||||
|
||||
Similarly::
|
||||
|
||||
expr = func.string_agg(
|
||||
table.c.a,
|
||||
aggregate_order_by(literal_column("','"), table.c.a)
|
||||
)
|
||||
stmt = select(expr)
|
||||
|
||||
Would represent::
|
||||
|
||||
SELECT string_agg(a, ',' ORDER BY a) FROM table;
|
||||
|
||||
.. versionchanged:: 1.2.13 - the ORDER BY argument may be multiple terms
|
||||
|
||||
.. seealso::
|
||||
|
||||
:class:`_functions.array_agg`
|
||||
|
||||
"""
|
||||
|
||||
__visit_name__ = "aggregate_order_by"
|
||||
|
||||
stringify_dialect = "postgresql"
|
||||
_traverse_internals: _TraverseInternalsType = [
|
||||
("target", InternalTraversal.dp_clauseelement),
|
||||
("type", InternalTraversal.dp_type),
|
||||
("order_by", InternalTraversal.dp_clauseelement),
|
||||
]
|
||||
|
||||
def __init__(self, target, *order_by):
|
||||
self.target = coercions.expect(roles.ExpressionElementRole, target)
|
||||
self.type = self.target.type
|
||||
|
||||
_lob = len(order_by)
|
||||
if _lob == 0:
|
||||
raise TypeError("at least one ORDER BY element is required")
|
||||
elif _lob == 1:
|
||||
self.order_by = coercions.expect(
|
||||
roles.ExpressionElementRole, order_by[0]
|
||||
)
|
||||
else:
|
||||
self.order_by = elements.ClauseList(
|
||||
*order_by, _literal_as_text_role=roles.ExpressionElementRole
|
||||
)
|
||||
|
||||
def self_group(self, against=None):
|
||||
return self
|
||||
|
||||
def get_children(self, **kwargs):
|
||||
return self.target, self.order_by
|
||||
|
||||
def _copy_internals(self, clone=elements._clone, **kw):
|
||||
self.target = clone(self.target, **kw)
|
||||
self.order_by = clone(self.order_by, **kw)
|
||||
|
||||
@property
|
||||
def _from_objects(self):
|
||||
return self.target._from_objects + self.order_by._from_objects
|
||||
|
||||
|
||||
class ExcludeConstraint(ColumnCollectionConstraint):
|
||||
"""A table-level EXCLUDE constraint.
|
||||
|
||||
Defines an EXCLUDE constraint as described in the `PostgreSQL
|
||||
documentation`__.
|
||||
|
||||
__ https://www.postgresql.org/docs/current/static/sql-createtable.html#SQL-CREATETABLE-EXCLUDE
|
||||
|
||||
""" # noqa
|
||||
|
||||
__visit_name__ = "exclude_constraint"
|
||||
|
||||
where = None
|
||||
inherit_cache = False
|
||||
|
||||
create_drop_stringify_dialect = "postgresql"
|
||||
|
||||
@elements._document_text_coercion(
|
||||
"where",
|
||||
":class:`.ExcludeConstraint`",
|
||||
":paramref:`.ExcludeConstraint.where`",
|
||||
)
|
||||
def __init__(self, *elements, **kw):
|
||||
r"""
|
||||
Create an :class:`.ExcludeConstraint` object.
|
||||
|
||||
E.g.::
|
||||
|
||||
const = ExcludeConstraint(
|
||||
(Column('period'), '&&'),
|
||||
(Column('group'), '='),
|
||||
where=(Column('group') != 'some group'),
|
||||
ops={'group': 'my_operator_class'}
|
||||
)
|
||||
|
||||
The constraint is normally embedded into the :class:`_schema.Table`
|
||||
construct
|
||||
directly, or added later using :meth:`.append_constraint`::
|
||||
|
||||
some_table = Table(
|
||||
'some_table', metadata,
|
||||
Column('id', Integer, primary_key=True),
|
||||
Column('period', TSRANGE()),
|
||||
Column('group', String)
|
||||
)
|
||||
|
||||
some_table.append_constraint(
|
||||
ExcludeConstraint(
|
||||
(some_table.c.period, '&&'),
|
||||
(some_table.c.group, '='),
|
||||
where=some_table.c.group != 'some group',
|
||||
name='some_table_excl_const',
|
||||
ops={'group': 'my_operator_class'}
|
||||
)
|
||||
)
|
||||
|
||||
The exclude constraint defined in this example requires the
|
||||
``btree_gist`` extension, that can be created using the
|
||||
command ``CREATE EXTENSION btree_gist;``.
|
||||
|
||||
:param \*elements:
|
||||
|
||||
A sequence of two tuples of the form ``(column, operator)`` where
|
||||
"column" is either a :class:`_schema.Column` object, or a SQL
|
||||
expression element (e.g. ``func.int8range(table.from, table.to)``)
|
||||
or the name of a column as string, and "operator" is a string
|
||||
containing the operator to use (e.g. `"&&"` or `"="`).
|
||||
|
||||
In order to specify a column name when a :class:`_schema.Column`
|
||||
object is not available, while ensuring
|
||||
that any necessary quoting rules take effect, an ad-hoc
|
||||
:class:`_schema.Column` or :func:`_expression.column`
|
||||
object should be used.
|
||||
The ``column`` may also be a string SQL expression when
|
||||
passed as :func:`_expression.literal_column` or
|
||||
:func:`_expression.text`
|
||||
|
||||
:param name:
|
||||
Optional, the in-database name of this constraint.
|
||||
|
||||
:param deferrable:
|
||||
Optional bool. If set, emit DEFERRABLE or NOT DEFERRABLE when
|
||||
issuing DDL for this constraint.
|
||||
|
||||
:param initially:
|
||||
Optional string. If set, emit INITIALLY <value> when issuing DDL
|
||||
for this constraint.
|
||||
|
||||
:param using:
|
||||
Optional string. If set, emit USING <index_method> when issuing DDL
|
||||
for this constraint. Defaults to 'gist'.
|
||||
|
||||
:param where:
|
||||
Optional SQL expression construct or literal SQL string.
|
||||
If set, emit WHERE <predicate> when issuing DDL
|
||||
for this constraint.
|
||||
|
||||
:param ops:
|
||||
Optional dictionary. Used to define operator classes for the
|
||||
elements; works the same way as that of the
|
||||
:ref:`postgresql_ops <postgresql_operator_classes>`
|
||||
parameter specified to the :class:`_schema.Index` construct.
|
||||
|
||||
.. versionadded:: 1.3.21
|
||||
|
||||
.. seealso::
|
||||
|
||||
:ref:`postgresql_operator_classes` - general description of how
|
||||
PostgreSQL operator classes are specified.
|
||||
|
||||
"""
|
||||
columns = []
|
||||
render_exprs = []
|
||||
self.operators = {}
|
||||
|
||||
expressions, operators = zip(*elements)
|
||||
|
||||
for (expr, column, strname, add_element), operator in zip(
|
||||
coercions.expect_col_expression_collection(
|
||||
roles.DDLConstraintColumnRole, expressions
|
||||
),
|
||||
operators,
|
||||
):
|
||||
if add_element is not None:
|
||||
columns.append(add_element)
|
||||
|
||||
name = column.name if column is not None else strname
|
||||
|
||||
if name is not None:
|
||||
# backwards compat
|
||||
self.operators[name] = operator
|
||||
|
||||
render_exprs.append((expr, name, operator))
|
||||
|
||||
self._render_exprs = render_exprs
|
||||
|
||||
ColumnCollectionConstraint.__init__(
|
||||
self,
|
||||
*columns,
|
||||
name=kw.get("name"),
|
||||
deferrable=kw.get("deferrable"),
|
||||
initially=kw.get("initially"),
|
||||
)
|
||||
self.using = kw.get("using", "gist")
|
||||
where = kw.get("where")
|
||||
if where is not None:
|
||||
self.where = coercions.expect(roles.StatementOptionRole, where)
|
||||
|
||||
self.ops = kw.get("ops", {})
|
||||
|
||||
def _set_parent(self, table, **kw):
|
||||
super()._set_parent(table)
|
||||
|
||||
self._render_exprs = [
|
||||
(
|
||||
expr if not isinstance(expr, str) else table.c[expr],
|
||||
name,
|
||||
operator,
|
||||
)
|
||||
for expr, name, operator in (self._render_exprs)
|
||||
]
|
||||
|
||||
def _copy(self, target_table=None, **kw):
|
||||
elements = [
|
||||
(
|
||||
schema._copy_expression(expr, self.parent, target_table),
|
||||
operator,
|
||||
)
|
||||
for expr, _, operator in self._render_exprs
|
||||
]
|
||||
c = self.__class__(
|
||||
*elements,
|
||||
name=self.name,
|
||||
deferrable=self.deferrable,
|
||||
initially=self.initially,
|
||||
where=self.where,
|
||||
using=self.using,
|
||||
)
|
||||
c.dispatch._update(self.dispatch)
|
||||
return c
|
||||
|
||||
|
||||
def array_agg(*arg, **kw):
|
||||
"""PostgreSQL-specific form of :class:`_functions.array_agg`, ensures
|
||||
return type is :class:`_postgresql.ARRAY` and not
|
||||
the plain :class:`_types.ARRAY`, unless an explicit ``type_``
|
||||
is passed.
|
||||
|
||||
"""
|
||||
kw["_default_array_type"] = ARRAY
|
||||
return functions.func.array_agg(*arg, **kw)
|
||||
|
||||
|
||||
class _regconfig_fn(functions.GenericFunction[_T]):
|
||||
inherit_cache = True
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
args = list(args)
|
||||
if len(args) > 1:
|
||||
initial_arg = coercions.expect(
|
||||
roles.ExpressionElementRole,
|
||||
args.pop(0),
|
||||
name=getattr(self, "name", None),
|
||||
apply_propagate_attrs=self,
|
||||
type_=types.REGCONFIG,
|
||||
)
|
||||
initial_arg = [initial_arg]
|
||||
else:
|
||||
initial_arg = []
|
||||
|
||||
addtl_args = [
|
||||
coercions.expect(
|
||||
roles.ExpressionElementRole,
|
||||
c,
|
||||
name=getattr(self, "name", None),
|
||||
apply_propagate_attrs=self,
|
||||
)
|
||||
for c in args
|
||||
]
|
||||
super().__init__(*(initial_arg + addtl_args), **kwargs)
|
||||
|
||||
|
||||
class to_tsvector(_regconfig_fn):
|
||||
"""The PostgreSQL ``to_tsvector`` SQL function.
|
||||
|
||||
This function applies automatic casting of the REGCONFIG argument
|
||||
to use the :class:`_postgresql.REGCONFIG` datatype automatically,
|
||||
and applies a return type of :class:`_postgresql.TSVECTOR`.
|
||||
|
||||
Assuming the PostgreSQL dialect has been imported, either by invoking
|
||||
``from sqlalchemy.dialects import postgresql``, or by creating a PostgreSQL
|
||||
engine using ``create_engine("postgresql...")``,
|
||||
:class:`_postgresql.to_tsvector` will be used automatically when invoking
|
||||
``sqlalchemy.func.to_tsvector()``, ensuring the correct argument and return
|
||||
type handlers are used at compile and execution time.
|
||||
|
||||
.. versionadded:: 2.0.0rc1
|
||||
|
||||
"""
|
||||
|
||||
inherit_cache = True
|
||||
type = types.TSVECTOR
|
||||
|
||||
|
||||
class to_tsquery(_regconfig_fn):
|
||||
"""The PostgreSQL ``to_tsquery`` SQL function.
|
||||
|
||||
This function applies automatic casting of the REGCONFIG argument
|
||||
to use the :class:`_postgresql.REGCONFIG` datatype automatically,
|
||||
and applies a return type of :class:`_postgresql.TSQUERY`.
|
||||
|
||||
Assuming the PostgreSQL dialect has been imported, either by invoking
|
||||
``from sqlalchemy.dialects import postgresql``, or by creating a PostgreSQL
|
||||
engine using ``create_engine("postgresql...")``,
|
||||
:class:`_postgresql.to_tsquery` will be used automatically when invoking
|
||||
``sqlalchemy.func.to_tsquery()``, ensuring the correct argument and return
|
||||
type handlers are used at compile and execution time.
|
||||
|
||||
.. versionadded:: 2.0.0rc1
|
||||
|
||||
"""
|
||||
|
||||
inherit_cache = True
|
||||
type = types.TSQUERY
|
||||
|
||||
|
||||
class plainto_tsquery(_regconfig_fn):
|
||||
"""The PostgreSQL ``plainto_tsquery`` SQL function.
|
||||
|
||||
This function applies automatic casting of the REGCONFIG argument
|
||||
to use the :class:`_postgresql.REGCONFIG` datatype automatically,
|
||||
and applies a return type of :class:`_postgresql.TSQUERY`.
|
||||
|
||||
Assuming the PostgreSQL dialect has been imported, either by invoking
|
||||
``from sqlalchemy.dialects import postgresql``, or by creating a PostgreSQL
|
||||
engine using ``create_engine("postgresql...")``,
|
||||
:class:`_postgresql.plainto_tsquery` will be used automatically when
|
||||
invoking ``sqlalchemy.func.plainto_tsquery()``, ensuring the correct
|
||||
argument and return type handlers are used at compile and execution time.
|
||||
|
||||
.. versionadded:: 2.0.0rc1
|
||||
|
||||
"""
|
||||
|
||||
inherit_cache = True
|
||||
type = types.TSQUERY
|
||||
|
||||
|
||||
class phraseto_tsquery(_regconfig_fn):
|
||||
"""The PostgreSQL ``phraseto_tsquery`` SQL function.
|
||||
|
||||
This function applies automatic casting of the REGCONFIG argument
|
||||
to use the :class:`_postgresql.REGCONFIG` datatype automatically,
|
||||
and applies a return type of :class:`_postgresql.TSQUERY`.
|
||||
|
||||
Assuming the PostgreSQL dialect has been imported, either by invoking
|
||||
``from sqlalchemy.dialects import postgresql``, or by creating a PostgreSQL
|
||||
engine using ``create_engine("postgresql...")``,
|
||||
:class:`_postgresql.phraseto_tsquery` will be used automatically when
|
||||
invoking ``sqlalchemy.func.phraseto_tsquery()``, ensuring the correct
|
||||
argument and return type handlers are used at compile and execution time.
|
||||
|
||||
.. versionadded:: 2.0.0rc1
|
||||
|
||||
"""
|
||||
|
||||
inherit_cache = True
|
||||
type = types.TSQUERY
|
||||
|
||||
|
||||
class websearch_to_tsquery(_regconfig_fn):
|
||||
"""The PostgreSQL ``websearch_to_tsquery`` SQL function.
|
||||
|
||||
This function applies automatic casting of the REGCONFIG argument
|
||||
to use the :class:`_postgresql.REGCONFIG` datatype automatically,
|
||||
and applies a return type of :class:`_postgresql.TSQUERY`.
|
||||
|
||||
Assuming the PostgreSQL dialect has been imported, either by invoking
|
||||
``from sqlalchemy.dialects import postgresql``, or by creating a PostgreSQL
|
||||
engine using ``create_engine("postgresql...")``,
|
||||
:class:`_postgresql.websearch_to_tsquery` will be used automatically when
|
||||
invoking ``sqlalchemy.func.websearch_to_tsquery()``, ensuring the correct
|
||||
argument and return type handlers are used at compile and execution time.
|
||||
|
||||
.. versionadded:: 2.0.0rc1
|
||||
|
||||
"""
|
||||
|
||||
inherit_cache = True
|
||||
type = types.TSQUERY
|
||||
|
||||
|
||||
class ts_headline(_regconfig_fn):
|
||||
"""The PostgreSQL ``ts_headline`` SQL function.
|
||||
|
||||
This function applies automatic casting of the REGCONFIG argument
|
||||
to use the :class:`_postgresql.REGCONFIG` datatype automatically,
|
||||
and applies a return type of :class:`_types.TEXT`.
|
||||
|
||||
Assuming the PostgreSQL dialect has been imported, either by invoking
|
||||
``from sqlalchemy.dialects import postgresql``, or by creating a PostgreSQL
|
||||
engine using ``create_engine("postgresql...")``,
|
||||
:class:`_postgresql.ts_headline` will be used automatically when invoking
|
||||
``sqlalchemy.func.ts_headline()``, ensuring the correct argument and return
|
||||
type handlers are used at compile and execution time.
|
||||
|
||||
.. versionadded:: 2.0.0rc1
|
||||
|
||||
"""
|
||||
|
||||
inherit_cache = True
|
||||
type = TEXT
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
args = list(args)
|
||||
|
||||
# parse types according to
|
||||
# https://www.postgresql.org/docs/current/textsearch-controls.html#TEXTSEARCH-HEADLINE
|
||||
if len(args) < 2:
|
||||
# invalid args; don't do anything
|
||||
has_regconfig = False
|
||||
elif (
|
||||
isinstance(args[1], elements.ColumnElement)
|
||||
and args[1].type._type_affinity is types.TSQUERY
|
||||
):
|
||||
# tsquery is second argument, no regconfig argument
|
||||
has_regconfig = False
|
||||
else:
|
||||
has_regconfig = True
|
||||
|
||||
if has_regconfig:
|
||||
initial_arg = coercions.expect(
|
||||
roles.ExpressionElementRole,
|
||||
args.pop(0),
|
||||
apply_propagate_attrs=self,
|
||||
name=getattr(self, "name", None),
|
||||
type_=types.REGCONFIG,
|
||||
)
|
||||
initial_arg = [initial_arg]
|
||||
else:
|
||||
initial_arg = []
|
||||
|
||||
addtl_args = [
|
||||
coercions.expect(
|
||||
roles.ExpressionElementRole,
|
||||
c,
|
||||
name=getattr(self, "name", None),
|
||||
apply_propagate_attrs=self,
|
||||
)
|
||||
for c in args
|
||||
]
|
||||
super().__init__(*(initial_arg + addtl_args), **kwargs)
|
|
@ -0,0 +1,397 @@
|
|||
# dialects/postgresql/hstore.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
import re
|
||||
|
||||
from .array import ARRAY
|
||||
from .operators import CONTAINED_BY
|
||||
from .operators import CONTAINS
|
||||
from .operators import GETITEM
|
||||
from .operators import HAS_ALL
|
||||
from .operators import HAS_ANY
|
||||
from .operators import HAS_KEY
|
||||
from ... import types as sqltypes
|
||||
from ...sql import functions as sqlfunc
|
||||
|
||||
|
||||
__all__ = ("HSTORE", "hstore")
|
||||
|
||||
|
||||
class HSTORE(sqltypes.Indexable, sqltypes.Concatenable, sqltypes.TypeEngine):
|
||||
"""Represent the PostgreSQL HSTORE type.
|
||||
|
||||
The :class:`.HSTORE` type stores dictionaries containing strings, e.g.::
|
||||
|
||||
data_table = Table('data_table', metadata,
|
||||
Column('id', Integer, primary_key=True),
|
||||
Column('data', HSTORE)
|
||||
)
|
||||
|
||||
with engine.connect() as conn:
|
||||
conn.execute(
|
||||
data_table.insert(),
|
||||
data = {"key1": "value1", "key2": "value2"}
|
||||
)
|
||||
|
||||
:class:`.HSTORE` provides for a wide range of operations, including:
|
||||
|
||||
* Index operations::
|
||||
|
||||
data_table.c.data['some key'] == 'some value'
|
||||
|
||||
* Containment operations::
|
||||
|
||||
data_table.c.data.has_key('some key')
|
||||
|
||||
data_table.c.data.has_all(['one', 'two', 'three'])
|
||||
|
||||
* Concatenation::
|
||||
|
||||
data_table.c.data + {"k1": "v1"}
|
||||
|
||||
For a full list of special methods see
|
||||
:class:`.HSTORE.comparator_factory`.
|
||||
|
||||
.. container:: topic
|
||||
|
||||
**Detecting Changes in HSTORE columns when using the ORM**
|
||||
|
||||
For usage with the SQLAlchemy ORM, it may be desirable to combine the
|
||||
usage of :class:`.HSTORE` with :class:`.MutableDict` dictionary now
|
||||
part of the :mod:`sqlalchemy.ext.mutable` extension. This extension
|
||||
will allow "in-place" changes to the dictionary, e.g. addition of new
|
||||
keys or replacement/removal of existing keys to/from the current
|
||||
dictionary, to produce events which will be detected by the unit of
|
||||
work::
|
||||
|
||||
from sqlalchemy.ext.mutable import MutableDict
|
||||
|
||||
class MyClass(Base):
|
||||
__tablename__ = 'data_table'
|
||||
|
||||
id = Column(Integer, primary_key=True)
|
||||
data = Column(MutableDict.as_mutable(HSTORE))
|
||||
|
||||
my_object = session.query(MyClass).one()
|
||||
|
||||
# in-place mutation, requires Mutable extension
|
||||
# in order for the ORM to detect
|
||||
my_object.data['some_key'] = 'some value'
|
||||
|
||||
session.commit()
|
||||
|
||||
When the :mod:`sqlalchemy.ext.mutable` extension is not used, the ORM
|
||||
will not be alerted to any changes to the contents of an existing
|
||||
dictionary, unless that dictionary value is re-assigned to the
|
||||
HSTORE-attribute itself, thus generating a change event.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:class:`.hstore` - render the PostgreSQL ``hstore()`` function.
|
||||
|
||||
|
||||
"""
|
||||
|
||||
__visit_name__ = "HSTORE"
|
||||
hashable = False
|
||||
text_type = sqltypes.Text()
|
||||
|
||||
def __init__(self, text_type=None):
|
||||
"""Construct a new :class:`.HSTORE`.
|
||||
|
||||
:param text_type: the type that should be used for indexed values.
|
||||
Defaults to :class:`_types.Text`.
|
||||
|
||||
"""
|
||||
if text_type is not None:
|
||||
self.text_type = text_type
|
||||
|
||||
class Comparator(
|
||||
sqltypes.Indexable.Comparator, sqltypes.Concatenable.Comparator
|
||||
):
|
||||
"""Define comparison operations for :class:`.HSTORE`."""
|
||||
|
||||
def has_key(self, other):
|
||||
"""Boolean expression. Test for presence of a key. Note that the
|
||||
key may be a SQLA expression.
|
||||
"""
|
||||
return self.operate(HAS_KEY, other, result_type=sqltypes.Boolean)
|
||||
|
||||
def has_all(self, other):
|
||||
"""Boolean expression. Test for presence of all keys in jsonb"""
|
||||
return self.operate(HAS_ALL, other, result_type=sqltypes.Boolean)
|
||||
|
||||
def has_any(self, other):
|
||||
"""Boolean expression. Test for presence of any key in jsonb"""
|
||||
return self.operate(HAS_ANY, other, result_type=sqltypes.Boolean)
|
||||
|
||||
def contains(self, other, **kwargs):
|
||||
"""Boolean expression. Test if keys (or array) are a superset
|
||||
of/contained the keys of the argument jsonb expression.
|
||||
|
||||
kwargs may be ignored by this operator but are required for API
|
||||
conformance.
|
||||
"""
|
||||
return self.operate(CONTAINS, other, result_type=sqltypes.Boolean)
|
||||
|
||||
def contained_by(self, other):
|
||||
"""Boolean expression. Test if keys are a proper subset of the
|
||||
keys of the argument jsonb expression.
|
||||
"""
|
||||
return self.operate(
|
||||
CONTAINED_BY, other, result_type=sqltypes.Boolean
|
||||
)
|
||||
|
||||
def _setup_getitem(self, index):
|
||||
return GETITEM, index, self.type.text_type
|
||||
|
||||
def defined(self, key):
|
||||
"""Boolean expression. Test for presence of a non-NULL value for
|
||||
the key. Note that the key may be a SQLA expression.
|
||||
"""
|
||||
return _HStoreDefinedFunction(self.expr, key)
|
||||
|
||||
def delete(self, key):
|
||||
"""HStore expression. Returns the contents of this hstore with the
|
||||
given key deleted. Note that the key may be a SQLA expression.
|
||||
"""
|
||||
if isinstance(key, dict):
|
||||
key = _serialize_hstore(key)
|
||||
return _HStoreDeleteFunction(self.expr, key)
|
||||
|
||||
def slice(self, array):
|
||||
"""HStore expression. Returns a subset of an hstore defined by
|
||||
array of keys.
|
||||
"""
|
||||
return _HStoreSliceFunction(self.expr, array)
|
||||
|
||||
def keys(self):
|
||||
"""Text array expression. Returns array of keys."""
|
||||
return _HStoreKeysFunction(self.expr)
|
||||
|
||||
def vals(self):
|
||||
"""Text array expression. Returns array of values."""
|
||||
return _HStoreValsFunction(self.expr)
|
||||
|
||||
def array(self):
|
||||
"""Text array expression. Returns array of alternating keys and
|
||||
values.
|
||||
"""
|
||||
return _HStoreArrayFunction(self.expr)
|
||||
|
||||
def matrix(self):
|
||||
"""Text array expression. Returns array of [key, value] pairs."""
|
||||
return _HStoreMatrixFunction(self.expr)
|
||||
|
||||
comparator_factory = Comparator
|
||||
|
||||
def bind_processor(self, dialect):
|
||||
def process(value):
|
||||
if isinstance(value, dict):
|
||||
return _serialize_hstore(value)
|
||||
else:
|
||||
return value
|
||||
|
||||
return process
|
||||
|
||||
def result_processor(self, dialect, coltype):
|
||||
def process(value):
|
||||
if value is not None:
|
||||
return _parse_hstore(value)
|
||||
else:
|
||||
return value
|
||||
|
||||
return process
|
||||
|
||||
|
||||
class hstore(sqlfunc.GenericFunction):
|
||||
"""Construct an hstore value within a SQL expression using the
|
||||
PostgreSQL ``hstore()`` function.
|
||||
|
||||
The :class:`.hstore` function accepts one or two arguments as described
|
||||
in the PostgreSQL documentation.
|
||||
|
||||
E.g.::
|
||||
|
||||
from sqlalchemy.dialects.postgresql import array, hstore
|
||||
|
||||
select(hstore('key1', 'value1'))
|
||||
|
||||
select(
|
||||
hstore(
|
||||
array(['key1', 'key2', 'key3']),
|
||||
array(['value1', 'value2', 'value3'])
|
||||
)
|
||||
)
|
||||
|
||||
.. seealso::
|
||||
|
||||
:class:`.HSTORE` - the PostgreSQL ``HSTORE`` datatype.
|
||||
|
||||
"""
|
||||
|
||||
type = HSTORE
|
||||
name = "hstore"
|
||||
inherit_cache = True
|
||||
|
||||
|
||||
class _HStoreDefinedFunction(sqlfunc.GenericFunction):
|
||||
type = sqltypes.Boolean
|
||||
name = "defined"
|
||||
inherit_cache = True
|
||||
|
||||
|
||||
class _HStoreDeleteFunction(sqlfunc.GenericFunction):
|
||||
type = HSTORE
|
||||
name = "delete"
|
||||
inherit_cache = True
|
||||
|
||||
|
||||
class _HStoreSliceFunction(sqlfunc.GenericFunction):
|
||||
type = HSTORE
|
||||
name = "slice"
|
||||
inherit_cache = True
|
||||
|
||||
|
||||
class _HStoreKeysFunction(sqlfunc.GenericFunction):
|
||||
type = ARRAY(sqltypes.Text)
|
||||
name = "akeys"
|
||||
inherit_cache = True
|
||||
|
||||
|
||||
class _HStoreValsFunction(sqlfunc.GenericFunction):
|
||||
type = ARRAY(sqltypes.Text)
|
||||
name = "avals"
|
||||
inherit_cache = True
|
||||
|
||||
|
||||
class _HStoreArrayFunction(sqlfunc.GenericFunction):
|
||||
type = ARRAY(sqltypes.Text)
|
||||
name = "hstore_to_array"
|
||||
inherit_cache = True
|
||||
|
||||
|
||||
class _HStoreMatrixFunction(sqlfunc.GenericFunction):
|
||||
type = ARRAY(sqltypes.Text)
|
||||
name = "hstore_to_matrix"
|
||||
inherit_cache = True
|
||||
|
||||
|
||||
#
|
||||
# parsing. note that none of this is used with the psycopg2 backend,
|
||||
# which provides its own native extensions.
|
||||
#
|
||||
|
||||
# My best guess at the parsing rules of hstore literals, since no formal
|
||||
# grammar is given. This is mostly reverse engineered from PG's input parser
|
||||
# behavior.
|
||||
HSTORE_PAIR_RE = re.compile(
|
||||
r"""
|
||||
(
|
||||
"(?P<key> (\\ . | [^"])* )" # Quoted key
|
||||
)
|
||||
[ ]* => [ ]* # Pair operator, optional adjoining whitespace
|
||||
(
|
||||
(?P<value_null> NULL ) # NULL value
|
||||
| "(?P<value> (\\ . | [^"])* )" # Quoted value
|
||||
)
|
||||
""",
|
||||
re.VERBOSE,
|
||||
)
|
||||
|
||||
HSTORE_DELIMITER_RE = re.compile(
|
||||
r"""
|
||||
[ ]* , [ ]*
|
||||
""",
|
||||
re.VERBOSE,
|
||||
)
|
||||
|
||||
|
||||
def _parse_error(hstore_str, pos):
|
||||
"""format an unmarshalling error."""
|
||||
|
||||
ctx = 20
|
||||
hslen = len(hstore_str)
|
||||
|
||||
parsed_tail = hstore_str[max(pos - ctx - 1, 0) : min(pos, hslen)]
|
||||
residual = hstore_str[min(pos, hslen) : min(pos + ctx + 1, hslen)]
|
||||
|
||||
if len(parsed_tail) > ctx:
|
||||
parsed_tail = "[...]" + parsed_tail[1:]
|
||||
if len(residual) > ctx:
|
||||
residual = residual[:-1] + "[...]"
|
||||
|
||||
return "After %r, could not parse residual at position %d: %r" % (
|
||||
parsed_tail,
|
||||
pos,
|
||||
residual,
|
||||
)
|
||||
|
||||
|
||||
def _parse_hstore(hstore_str):
|
||||
"""Parse an hstore from its literal string representation.
|
||||
|
||||
Attempts to approximate PG's hstore input parsing rules as closely as
|
||||
possible. Although currently this is not strictly necessary, since the
|
||||
current implementation of hstore's output syntax is stricter than what it
|
||||
accepts as input, the documentation makes no guarantees that will always
|
||||
be the case.
|
||||
|
||||
|
||||
|
||||
"""
|
||||
result = {}
|
||||
pos = 0
|
||||
pair_match = HSTORE_PAIR_RE.match(hstore_str)
|
||||
|
||||
while pair_match is not None:
|
||||
key = pair_match.group("key").replace(r"\"", '"').replace("\\\\", "\\")
|
||||
if pair_match.group("value_null"):
|
||||
value = None
|
||||
else:
|
||||
value = (
|
||||
pair_match.group("value")
|
||||
.replace(r"\"", '"')
|
||||
.replace("\\\\", "\\")
|
||||
)
|
||||
result[key] = value
|
||||
|
||||
pos += pair_match.end()
|
||||
|
||||
delim_match = HSTORE_DELIMITER_RE.match(hstore_str[pos:])
|
||||
if delim_match is not None:
|
||||
pos += delim_match.end()
|
||||
|
||||
pair_match = HSTORE_PAIR_RE.match(hstore_str[pos:])
|
||||
|
||||
if pos != len(hstore_str):
|
||||
raise ValueError(_parse_error(hstore_str, pos))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _serialize_hstore(val):
|
||||
"""Serialize a dictionary into an hstore literal. Keys and values must
|
||||
both be strings (except None for values).
|
||||
|
||||
"""
|
||||
|
||||
def esc(s, position):
|
||||
if position == "value" and s is None:
|
||||
return "NULL"
|
||||
elif isinstance(s, str):
|
||||
return '"%s"' % s.replace("\\", "\\\\").replace('"', r"\"")
|
||||
else:
|
||||
raise ValueError(
|
||||
"%r in %s position is not a string." % (s, position)
|
||||
)
|
||||
|
||||
return ", ".join(
|
||||
"%s=>%s" % (esc(k, "key"), esc(v, "value")) for k, v in val.items()
|
||||
)
|
|
@ -0,0 +1,325 @@
|
|||
# dialects/postgresql/json.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
from .array import ARRAY
|
||||
from .array import array as _pg_array
|
||||
from .operators import ASTEXT
|
||||
from .operators import CONTAINED_BY
|
||||
from .operators import CONTAINS
|
||||
from .operators import DELETE_PATH
|
||||
from .operators import HAS_ALL
|
||||
from .operators import HAS_ANY
|
||||
from .operators import HAS_KEY
|
||||
from .operators import JSONPATH_ASTEXT
|
||||
from .operators import PATH_EXISTS
|
||||
from .operators import PATH_MATCH
|
||||
from ... import types as sqltypes
|
||||
from ...sql import cast
|
||||
|
||||
__all__ = ("JSON", "JSONB")
|
||||
|
||||
|
||||
class JSONPathType(sqltypes.JSON.JSONPathType):
|
||||
def _processor(self, dialect, super_proc):
|
||||
def process(value):
|
||||
if isinstance(value, str):
|
||||
# If it's already a string assume that it's in json path
|
||||
# format. This allows using cast with json paths literals
|
||||
return value
|
||||
elif value:
|
||||
# If it's already a string assume that it's in json path
|
||||
# format. This allows using cast with json paths literals
|
||||
value = "{%s}" % (", ".join(map(str, value)))
|
||||
else:
|
||||
value = "{}"
|
||||
if super_proc:
|
||||
value = super_proc(value)
|
||||
return value
|
||||
|
||||
return process
|
||||
|
||||
def bind_processor(self, dialect):
|
||||
return self._processor(dialect, self.string_bind_processor(dialect))
|
||||
|
||||
def literal_processor(self, dialect):
|
||||
return self._processor(dialect, self.string_literal_processor(dialect))
|
||||
|
||||
|
||||
class JSONPATH(JSONPathType):
|
||||
"""JSON Path Type.
|
||||
|
||||
This is usually required to cast literal values to json path when using
|
||||
json search like function, such as ``jsonb_path_query_array`` or
|
||||
``jsonb_path_exists``::
|
||||
|
||||
stmt = sa.select(
|
||||
sa.func.jsonb_path_query_array(
|
||||
table.c.jsonb_col, cast("$.address.id", JSONPATH)
|
||||
)
|
||||
)
|
||||
|
||||
"""
|
||||
|
||||
__visit_name__ = "JSONPATH"
|
||||
|
||||
|
||||
class JSON(sqltypes.JSON):
|
||||
"""Represent the PostgreSQL JSON type.
|
||||
|
||||
:class:`_postgresql.JSON` is used automatically whenever the base
|
||||
:class:`_types.JSON` datatype is used against a PostgreSQL backend,
|
||||
however base :class:`_types.JSON` datatype does not provide Python
|
||||
accessors for PostgreSQL-specific comparison methods such as
|
||||
:meth:`_postgresql.JSON.Comparator.astext`; additionally, to use
|
||||
PostgreSQL ``JSONB``, the :class:`_postgresql.JSONB` datatype should
|
||||
be used explicitly.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:class:`_types.JSON` - main documentation for the generic
|
||||
cross-platform JSON datatype.
|
||||
|
||||
The operators provided by the PostgreSQL version of :class:`_types.JSON`
|
||||
include:
|
||||
|
||||
* Index operations (the ``->`` operator)::
|
||||
|
||||
data_table.c.data['some key']
|
||||
|
||||
data_table.c.data[5]
|
||||
|
||||
|
||||
* Index operations returning text (the ``->>`` operator)::
|
||||
|
||||
data_table.c.data['some key'].astext == 'some value'
|
||||
|
||||
Note that equivalent functionality is available via the
|
||||
:attr:`.JSON.Comparator.as_string` accessor.
|
||||
|
||||
* Index operations with CAST
|
||||
(equivalent to ``CAST(col ->> ['some key'] AS <type>)``)::
|
||||
|
||||
data_table.c.data['some key'].astext.cast(Integer) == 5
|
||||
|
||||
Note that equivalent functionality is available via the
|
||||
:attr:`.JSON.Comparator.as_integer` and similar accessors.
|
||||
|
||||
* Path index operations (the ``#>`` operator)::
|
||||
|
||||
data_table.c.data[('key_1', 'key_2', 5, ..., 'key_n')]
|
||||
|
||||
* Path index operations returning text (the ``#>>`` operator)::
|
||||
|
||||
data_table.c.data[('key_1', 'key_2', 5, ..., 'key_n')].astext == 'some value'
|
||||
|
||||
Index operations return an expression object whose type defaults to
|
||||
:class:`_types.JSON` by default,
|
||||
so that further JSON-oriented instructions
|
||||
may be called upon the result type.
|
||||
|
||||
Custom serializers and deserializers are specified at the dialect level,
|
||||
that is using :func:`_sa.create_engine`. The reason for this is that when
|
||||
using psycopg2, the DBAPI only allows serializers at the per-cursor
|
||||
or per-connection level. E.g.::
|
||||
|
||||
engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test",
|
||||
json_serializer=my_serialize_fn,
|
||||
json_deserializer=my_deserialize_fn
|
||||
)
|
||||
|
||||
When using the psycopg2 dialect, the json_deserializer is registered
|
||||
against the database using ``psycopg2.extras.register_default_json``.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:class:`_types.JSON` - Core level JSON type
|
||||
|
||||
:class:`_postgresql.JSONB`
|
||||
|
||||
""" # noqa
|
||||
|
||||
astext_type = sqltypes.Text()
|
||||
|
||||
def __init__(self, none_as_null=False, astext_type=None):
|
||||
"""Construct a :class:`_types.JSON` type.
|
||||
|
||||
:param none_as_null: if True, persist the value ``None`` as a
|
||||
SQL NULL value, not the JSON encoding of ``null``. Note that
|
||||
when this flag is False, the :func:`.null` construct can still
|
||||
be used to persist a NULL value::
|
||||
|
||||
from sqlalchemy import null
|
||||
conn.execute(table.insert(), data=null())
|
||||
|
||||
.. seealso::
|
||||
|
||||
:attr:`_types.JSON.NULL`
|
||||
|
||||
:param astext_type: the type to use for the
|
||||
:attr:`.JSON.Comparator.astext`
|
||||
accessor on indexed attributes. Defaults to :class:`_types.Text`.
|
||||
|
||||
"""
|
||||
super().__init__(none_as_null=none_as_null)
|
||||
if astext_type is not None:
|
||||
self.astext_type = astext_type
|
||||
|
||||
class Comparator(sqltypes.JSON.Comparator):
|
||||
"""Define comparison operations for :class:`_types.JSON`."""
|
||||
|
||||
@property
|
||||
def astext(self):
|
||||
"""On an indexed expression, use the "astext" (e.g. "->>")
|
||||
conversion when rendered in SQL.
|
||||
|
||||
E.g.::
|
||||
|
||||
select(data_table.c.data['some key'].astext)
|
||||
|
||||
.. seealso::
|
||||
|
||||
:meth:`_expression.ColumnElement.cast`
|
||||
|
||||
"""
|
||||
if isinstance(self.expr.right.type, sqltypes.JSON.JSONPathType):
|
||||
return self.expr.left.operate(
|
||||
JSONPATH_ASTEXT,
|
||||
self.expr.right,
|
||||
result_type=self.type.astext_type,
|
||||
)
|
||||
else:
|
||||
return self.expr.left.operate(
|
||||
ASTEXT, self.expr.right, result_type=self.type.astext_type
|
||||
)
|
||||
|
||||
comparator_factory = Comparator
|
||||
|
||||
|
||||
class JSONB(JSON):
|
||||
"""Represent the PostgreSQL JSONB type.
|
||||
|
||||
The :class:`_postgresql.JSONB` type stores arbitrary JSONB format data,
|
||||
e.g.::
|
||||
|
||||
data_table = Table('data_table', metadata,
|
||||
Column('id', Integer, primary_key=True),
|
||||
Column('data', JSONB)
|
||||
)
|
||||
|
||||
with engine.connect() as conn:
|
||||
conn.execute(
|
||||
data_table.insert(),
|
||||
data = {"key1": "value1", "key2": "value2"}
|
||||
)
|
||||
|
||||
The :class:`_postgresql.JSONB` type includes all operations provided by
|
||||
:class:`_types.JSON`, including the same behaviors for indexing
|
||||
operations.
|
||||
It also adds additional operators specific to JSONB, including
|
||||
:meth:`.JSONB.Comparator.has_key`, :meth:`.JSONB.Comparator.has_all`,
|
||||
:meth:`.JSONB.Comparator.has_any`, :meth:`.JSONB.Comparator.contains`,
|
||||
:meth:`.JSONB.Comparator.contained_by`,
|
||||
:meth:`.JSONB.Comparator.delete_path`,
|
||||
:meth:`.JSONB.Comparator.path_exists` and
|
||||
:meth:`.JSONB.Comparator.path_match`.
|
||||
|
||||
Like the :class:`_types.JSON` type, the :class:`_postgresql.JSONB`
|
||||
type does not detect
|
||||
in-place changes when used with the ORM, unless the
|
||||
:mod:`sqlalchemy.ext.mutable` extension is used.
|
||||
|
||||
Custom serializers and deserializers
|
||||
are shared with the :class:`_types.JSON` class,
|
||||
using the ``json_serializer``
|
||||
and ``json_deserializer`` keyword arguments. These must be specified
|
||||
at the dialect level using :func:`_sa.create_engine`. When using
|
||||
psycopg2, the serializers are associated with the jsonb type using
|
||||
``psycopg2.extras.register_default_jsonb`` on a per-connection basis,
|
||||
in the same way that ``psycopg2.extras.register_default_json`` is used
|
||||
to register these handlers with the json type.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:class:`_types.JSON`
|
||||
|
||||
"""
|
||||
|
||||
__visit_name__ = "JSONB"
|
||||
|
||||
class Comparator(JSON.Comparator):
|
||||
"""Define comparison operations for :class:`_types.JSON`."""
|
||||
|
||||
def has_key(self, other):
|
||||
"""Boolean expression. Test for presence of a key. Note that the
|
||||
key may be a SQLA expression.
|
||||
"""
|
||||
return self.operate(HAS_KEY, other, result_type=sqltypes.Boolean)
|
||||
|
||||
def has_all(self, other):
|
||||
"""Boolean expression. Test for presence of all keys in jsonb"""
|
||||
return self.operate(HAS_ALL, other, result_type=sqltypes.Boolean)
|
||||
|
||||
def has_any(self, other):
|
||||
"""Boolean expression. Test for presence of any key in jsonb"""
|
||||
return self.operate(HAS_ANY, other, result_type=sqltypes.Boolean)
|
||||
|
||||
def contains(self, other, **kwargs):
|
||||
"""Boolean expression. Test if keys (or array) are a superset
|
||||
of/contained the keys of the argument jsonb expression.
|
||||
|
||||
kwargs may be ignored by this operator but are required for API
|
||||
conformance.
|
||||
"""
|
||||
return self.operate(CONTAINS, other, result_type=sqltypes.Boolean)
|
||||
|
||||
def contained_by(self, other):
|
||||
"""Boolean expression. Test if keys are a proper subset of the
|
||||
keys of the argument jsonb expression.
|
||||
"""
|
||||
return self.operate(
|
||||
CONTAINED_BY, other, result_type=sqltypes.Boolean
|
||||
)
|
||||
|
||||
def delete_path(self, array):
|
||||
"""JSONB expression. Deletes field or array element specified in
|
||||
the argument array.
|
||||
|
||||
The input may be a list of strings that will be coerced to an
|
||||
``ARRAY`` or an instance of :meth:`_postgres.array`.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
if not isinstance(array, _pg_array):
|
||||
array = _pg_array(array)
|
||||
right_side = cast(array, ARRAY(sqltypes.TEXT))
|
||||
return self.operate(DELETE_PATH, right_side, result_type=JSONB)
|
||||
|
||||
def path_exists(self, other):
|
||||
"""Boolean expression. Test for presence of item given by the
|
||||
argument JSONPath expression.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return self.operate(
|
||||
PATH_EXISTS, other, result_type=sqltypes.Boolean
|
||||
)
|
||||
|
||||
def path_match(self, other):
|
||||
"""Boolean expression. Test if JSONPath predicate given by the
|
||||
argument JSONPath expression matches.
|
||||
|
||||
Only the first item of the result is taken into account.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return self.operate(
|
||||
PATH_MATCH, other, result_type=sqltypes.Boolean
|
||||
)
|
||||
|
||||
comparator_factory = Comparator
|
|
@ -0,0 +1,495 @@
|
|||
# dialects/postgresql/named_types.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
from typing import Type
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Union
|
||||
|
||||
from ... import schema
|
||||
from ... import util
|
||||
from ...sql import coercions
|
||||
from ...sql import elements
|
||||
from ...sql import roles
|
||||
from ...sql import sqltypes
|
||||
from ...sql import type_api
|
||||
from ...sql.base import _NoArg
|
||||
from ...sql.ddl import InvokeCreateDDLBase
|
||||
from ...sql.ddl import InvokeDropDDLBase
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...sql._typing import _TypeEngineArgument
|
||||
|
||||
|
||||
class NamedType(sqltypes.TypeEngine):
|
||||
"""Base for named types."""
|
||||
|
||||
__abstract__ = True
|
||||
DDLGenerator: Type[NamedTypeGenerator]
|
||||
DDLDropper: Type[NamedTypeDropper]
|
||||
create_type: bool
|
||||
|
||||
def create(self, bind, checkfirst=True, **kw):
|
||||
"""Emit ``CREATE`` DDL for this type.
|
||||
|
||||
:param bind: a connectable :class:`_engine.Engine`,
|
||||
:class:`_engine.Connection`, or similar object to emit
|
||||
SQL.
|
||||
:param checkfirst: if ``True``, a query against
|
||||
the PG catalog will be first performed to see
|
||||
if the type does not exist already before
|
||||
creating.
|
||||
|
||||
"""
|
||||
bind._run_ddl_visitor(self.DDLGenerator, self, checkfirst=checkfirst)
|
||||
|
||||
def drop(self, bind, checkfirst=True, **kw):
|
||||
"""Emit ``DROP`` DDL for this type.
|
||||
|
||||
:param bind: a connectable :class:`_engine.Engine`,
|
||||
:class:`_engine.Connection`, or similar object to emit
|
||||
SQL.
|
||||
:param checkfirst: if ``True``, a query against
|
||||
the PG catalog will be first performed to see
|
||||
if the type actually exists before dropping.
|
||||
|
||||
"""
|
||||
bind._run_ddl_visitor(self.DDLDropper, self, checkfirst=checkfirst)
|
||||
|
||||
def _check_for_name_in_memos(self, checkfirst, kw):
|
||||
"""Look in the 'ddl runner' for 'memos', then
|
||||
note our name in that collection.
|
||||
|
||||
This to ensure a particular named type is operated
|
||||
upon only once within any kind of create/drop
|
||||
sequence without relying upon "checkfirst".
|
||||
|
||||
"""
|
||||
if not self.create_type:
|
||||
return True
|
||||
if "_ddl_runner" in kw:
|
||||
ddl_runner = kw["_ddl_runner"]
|
||||
type_name = f"pg_{self.__visit_name__}"
|
||||
if type_name in ddl_runner.memo:
|
||||
existing = ddl_runner.memo[type_name]
|
||||
else:
|
||||
existing = ddl_runner.memo[type_name] = set()
|
||||
present = (self.schema, self.name) in existing
|
||||
existing.add((self.schema, self.name))
|
||||
return present
|
||||
else:
|
||||
return False
|
||||
|
||||
def _on_table_create(self, target, bind, checkfirst=False, **kw):
|
||||
if (
|
||||
checkfirst
|
||||
or (
|
||||
not self.metadata
|
||||
and not kw.get("_is_metadata_operation", False)
|
||||
)
|
||||
) and not self._check_for_name_in_memos(checkfirst, kw):
|
||||
self.create(bind=bind, checkfirst=checkfirst)
|
||||
|
||||
def _on_table_drop(self, target, bind, checkfirst=False, **kw):
|
||||
if (
|
||||
not self.metadata
|
||||
and not kw.get("_is_metadata_operation", False)
|
||||
and not self._check_for_name_in_memos(checkfirst, kw)
|
||||
):
|
||||
self.drop(bind=bind, checkfirst=checkfirst)
|
||||
|
||||
def _on_metadata_create(self, target, bind, checkfirst=False, **kw):
|
||||
if not self._check_for_name_in_memos(checkfirst, kw):
|
||||
self.create(bind=bind, checkfirst=checkfirst)
|
||||
|
||||
def _on_metadata_drop(self, target, bind, checkfirst=False, **kw):
|
||||
if not self._check_for_name_in_memos(checkfirst, kw):
|
||||
self.drop(bind=bind, checkfirst=checkfirst)
|
||||
|
||||
|
||||
class NamedTypeGenerator(InvokeCreateDDLBase):
|
||||
def __init__(self, dialect, connection, checkfirst=False, **kwargs):
|
||||
super().__init__(connection, **kwargs)
|
||||
self.checkfirst = checkfirst
|
||||
|
||||
def _can_create_type(self, type_):
|
||||
if not self.checkfirst:
|
||||
return True
|
||||
|
||||
effective_schema = self.connection.schema_for_object(type_)
|
||||
return not self.connection.dialect.has_type(
|
||||
self.connection, type_.name, schema=effective_schema
|
||||
)
|
||||
|
||||
|
||||
class NamedTypeDropper(InvokeDropDDLBase):
|
||||
def __init__(self, dialect, connection, checkfirst=False, **kwargs):
|
||||
super().__init__(connection, **kwargs)
|
||||
self.checkfirst = checkfirst
|
||||
|
||||
def _can_drop_type(self, type_):
|
||||
if not self.checkfirst:
|
||||
return True
|
||||
|
||||
effective_schema = self.connection.schema_for_object(type_)
|
||||
return self.connection.dialect.has_type(
|
||||
self.connection, type_.name, schema=effective_schema
|
||||
)
|
||||
|
||||
|
||||
class EnumGenerator(NamedTypeGenerator):
|
||||
def visit_enum(self, enum):
|
||||
if not self._can_create_type(enum):
|
||||
return
|
||||
|
||||
with self.with_ddl_events(enum):
|
||||
self.connection.execute(CreateEnumType(enum))
|
||||
|
||||
|
||||
class EnumDropper(NamedTypeDropper):
|
||||
def visit_enum(self, enum):
|
||||
if not self._can_drop_type(enum):
|
||||
return
|
||||
|
||||
with self.with_ddl_events(enum):
|
||||
self.connection.execute(DropEnumType(enum))
|
||||
|
||||
|
||||
class ENUM(NamedType, type_api.NativeForEmulated, sqltypes.Enum):
|
||||
"""PostgreSQL ENUM type.
|
||||
|
||||
This is a subclass of :class:`_types.Enum` which includes
|
||||
support for PG's ``CREATE TYPE`` and ``DROP TYPE``.
|
||||
|
||||
When the builtin type :class:`_types.Enum` is used and the
|
||||
:paramref:`.Enum.native_enum` flag is left at its default of
|
||||
True, the PostgreSQL backend will use a :class:`_postgresql.ENUM`
|
||||
type as the implementation, so the special create/drop rules
|
||||
will be used.
|
||||
|
||||
The create/drop behavior of ENUM is necessarily intricate, due to the
|
||||
awkward relationship the ENUM type has in relationship to the
|
||||
parent table, in that it may be "owned" by just a single table, or
|
||||
may be shared among many tables.
|
||||
|
||||
When using :class:`_types.Enum` or :class:`_postgresql.ENUM`
|
||||
in an "inline" fashion, the ``CREATE TYPE`` and ``DROP TYPE`` is emitted
|
||||
corresponding to when the :meth:`_schema.Table.create` and
|
||||
:meth:`_schema.Table.drop`
|
||||
methods are called::
|
||||
|
||||
table = Table('sometable', metadata,
|
||||
Column('some_enum', ENUM('a', 'b', 'c', name='myenum'))
|
||||
)
|
||||
|
||||
table.create(engine) # will emit CREATE ENUM and CREATE TABLE
|
||||
table.drop(engine) # will emit DROP TABLE and DROP ENUM
|
||||
|
||||
To use a common enumerated type between multiple tables, the best
|
||||
practice is to declare the :class:`_types.Enum` or
|
||||
:class:`_postgresql.ENUM` independently, and associate it with the
|
||||
:class:`_schema.MetaData` object itself::
|
||||
|
||||
my_enum = ENUM('a', 'b', 'c', name='myenum', metadata=metadata)
|
||||
|
||||
t1 = Table('sometable_one', metadata,
|
||||
Column('some_enum', myenum)
|
||||
)
|
||||
|
||||
t2 = Table('sometable_two', metadata,
|
||||
Column('some_enum', myenum)
|
||||
)
|
||||
|
||||
When this pattern is used, care must still be taken at the level
|
||||
of individual table creates. Emitting CREATE TABLE without also
|
||||
specifying ``checkfirst=True`` will still cause issues::
|
||||
|
||||
t1.create(engine) # will fail: no such type 'myenum'
|
||||
|
||||
If we specify ``checkfirst=True``, the individual table-level create
|
||||
operation will check for the ``ENUM`` and create if not exists::
|
||||
|
||||
# will check if enum exists, and emit CREATE TYPE if not
|
||||
t1.create(engine, checkfirst=True)
|
||||
|
||||
When using a metadata-level ENUM type, the type will always be created
|
||||
and dropped if either the metadata-wide create/drop is called::
|
||||
|
||||
metadata.create_all(engine) # will emit CREATE TYPE
|
||||
metadata.drop_all(engine) # will emit DROP TYPE
|
||||
|
||||
The type can also be created and dropped directly::
|
||||
|
||||
my_enum.create(engine)
|
||||
my_enum.drop(engine)
|
||||
|
||||
"""
|
||||
|
||||
native_enum = True
|
||||
DDLGenerator = EnumGenerator
|
||||
DDLDropper = EnumDropper
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*enums,
|
||||
name: Union[str, _NoArg, None] = _NoArg.NO_ARG,
|
||||
create_type: bool = True,
|
||||
**kw,
|
||||
):
|
||||
"""Construct an :class:`_postgresql.ENUM`.
|
||||
|
||||
Arguments are the same as that of
|
||||
:class:`_types.Enum`, but also including
|
||||
the following parameters.
|
||||
|
||||
:param create_type: Defaults to True.
|
||||
Indicates that ``CREATE TYPE`` should be
|
||||
emitted, after optionally checking for the
|
||||
presence of the type, when the parent
|
||||
table is being created; and additionally
|
||||
that ``DROP TYPE`` is called when the table
|
||||
is dropped. When ``False``, no check
|
||||
will be performed and no ``CREATE TYPE``
|
||||
or ``DROP TYPE`` is emitted, unless
|
||||
:meth:`~.postgresql.ENUM.create`
|
||||
or :meth:`~.postgresql.ENUM.drop`
|
||||
are called directly.
|
||||
Setting to ``False`` is helpful
|
||||
when invoking a creation scheme to a SQL file
|
||||
without access to the actual database -
|
||||
the :meth:`~.postgresql.ENUM.create` and
|
||||
:meth:`~.postgresql.ENUM.drop` methods can
|
||||
be used to emit SQL to a target bind.
|
||||
|
||||
"""
|
||||
native_enum = kw.pop("native_enum", None)
|
||||
if native_enum is False:
|
||||
util.warn(
|
||||
"the native_enum flag does not apply to the "
|
||||
"sqlalchemy.dialects.postgresql.ENUM datatype; this type "
|
||||
"always refers to ENUM. Use sqlalchemy.types.Enum for "
|
||||
"non-native enum."
|
||||
)
|
||||
self.create_type = create_type
|
||||
if name is not _NoArg.NO_ARG:
|
||||
kw["name"] = name
|
||||
super().__init__(*enums, **kw)
|
||||
|
||||
def coerce_compared_value(self, op, value):
|
||||
super_coerced_type = super().coerce_compared_value(op, value)
|
||||
if (
|
||||
super_coerced_type._type_affinity
|
||||
is type_api.STRINGTYPE._type_affinity
|
||||
):
|
||||
return self
|
||||
else:
|
||||
return super_coerced_type
|
||||
|
||||
@classmethod
|
||||
def __test_init__(cls):
|
||||
return cls(name="name")
|
||||
|
||||
@classmethod
|
||||
def adapt_emulated_to_native(cls, impl, **kw):
|
||||
"""Produce a PostgreSQL native :class:`_postgresql.ENUM` from plain
|
||||
:class:`.Enum`.
|
||||
|
||||
"""
|
||||
kw.setdefault("validate_strings", impl.validate_strings)
|
||||
kw.setdefault("name", impl.name)
|
||||
kw.setdefault("schema", impl.schema)
|
||||
kw.setdefault("inherit_schema", impl.inherit_schema)
|
||||
kw.setdefault("metadata", impl.metadata)
|
||||
kw.setdefault("_create_events", False)
|
||||
kw.setdefault("values_callable", impl.values_callable)
|
||||
kw.setdefault("omit_aliases", impl._omit_aliases)
|
||||
kw.setdefault("_adapted_from", impl)
|
||||
if type_api._is_native_for_emulated(impl.__class__):
|
||||
kw.setdefault("create_type", impl.create_type)
|
||||
|
||||
return cls(**kw)
|
||||
|
||||
def create(self, bind=None, checkfirst=True):
|
||||
"""Emit ``CREATE TYPE`` for this
|
||||
:class:`_postgresql.ENUM`.
|
||||
|
||||
If the underlying dialect does not support
|
||||
PostgreSQL CREATE TYPE, no action is taken.
|
||||
|
||||
:param bind: a connectable :class:`_engine.Engine`,
|
||||
:class:`_engine.Connection`, or similar object to emit
|
||||
SQL.
|
||||
:param checkfirst: if ``True``, a query against
|
||||
the PG catalog will be first performed to see
|
||||
if the type does not exist already before
|
||||
creating.
|
||||
|
||||
"""
|
||||
if not bind.dialect.supports_native_enum:
|
||||
return
|
||||
|
||||
super().create(bind, checkfirst=checkfirst)
|
||||
|
||||
def drop(self, bind=None, checkfirst=True):
|
||||
"""Emit ``DROP TYPE`` for this
|
||||
:class:`_postgresql.ENUM`.
|
||||
|
||||
If the underlying dialect does not support
|
||||
PostgreSQL DROP TYPE, no action is taken.
|
||||
|
||||
:param bind: a connectable :class:`_engine.Engine`,
|
||||
:class:`_engine.Connection`, or similar object to emit
|
||||
SQL.
|
||||
:param checkfirst: if ``True``, a query against
|
||||
the PG catalog will be first performed to see
|
||||
if the type actually exists before dropping.
|
||||
|
||||
"""
|
||||
if not bind.dialect.supports_native_enum:
|
||||
return
|
||||
|
||||
super().drop(bind, checkfirst=checkfirst)
|
||||
|
||||
def get_dbapi_type(self, dbapi):
|
||||
"""dont return dbapi.STRING for ENUM in PostgreSQL, since that's
|
||||
a different type"""
|
||||
|
||||
return None
|
||||
|
||||
|
||||
class DomainGenerator(NamedTypeGenerator):
|
||||
def visit_DOMAIN(self, domain):
|
||||
if not self._can_create_type(domain):
|
||||
return
|
||||
with self.with_ddl_events(domain):
|
||||
self.connection.execute(CreateDomainType(domain))
|
||||
|
||||
|
||||
class DomainDropper(NamedTypeDropper):
|
||||
def visit_DOMAIN(self, domain):
|
||||
if not self._can_drop_type(domain):
|
||||
return
|
||||
|
||||
with self.with_ddl_events(domain):
|
||||
self.connection.execute(DropDomainType(domain))
|
||||
|
||||
|
||||
class DOMAIN(NamedType, sqltypes.SchemaType):
|
||||
r"""Represent the DOMAIN PostgreSQL type.
|
||||
|
||||
A domain is essentially a data type with optional constraints
|
||||
that restrict the allowed set of values. E.g.::
|
||||
|
||||
PositiveInt = DOMAIN(
|
||||
"pos_int", Integer, check="VALUE > 0", not_null=True
|
||||
)
|
||||
|
||||
UsPostalCode = DOMAIN(
|
||||
"us_postal_code",
|
||||
Text,
|
||||
check="VALUE ~ '^\d{5}$' OR VALUE ~ '^\d{5}-\d{4}$'"
|
||||
)
|
||||
|
||||
See the `PostgreSQL documentation`__ for additional details
|
||||
|
||||
__ https://www.postgresql.org/docs/current/sql-createdomain.html
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
"""
|
||||
|
||||
DDLGenerator = DomainGenerator
|
||||
DDLDropper = DomainDropper
|
||||
|
||||
__visit_name__ = "DOMAIN"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
data_type: _TypeEngineArgument[Any],
|
||||
*,
|
||||
collation: Optional[str] = None,
|
||||
default: Optional[Union[str, elements.TextClause]] = None,
|
||||
constraint_name: Optional[str] = None,
|
||||
not_null: Optional[bool] = None,
|
||||
check: Optional[str] = None,
|
||||
create_type: bool = True,
|
||||
**kw: Any,
|
||||
):
|
||||
"""
|
||||
Construct a DOMAIN.
|
||||
|
||||
:param name: the name of the domain
|
||||
:param data_type: The underlying data type of the domain.
|
||||
This can include array specifiers.
|
||||
:param collation: An optional collation for the domain.
|
||||
If no collation is specified, the underlying data type's default
|
||||
collation is used. The underlying type must be collatable if
|
||||
``collation`` is specified.
|
||||
:param default: The DEFAULT clause specifies a default value for
|
||||
columns of the domain data type. The default should be a string
|
||||
or a :func:`_expression.text` value.
|
||||
If no default value is specified, then the default value is
|
||||
the null value.
|
||||
:param constraint_name: An optional name for a constraint.
|
||||
If not specified, the backend generates a name.
|
||||
:param not_null: Values of this domain are prevented from being null.
|
||||
By default domain are allowed to be null. If not specified
|
||||
no nullability clause will be emitted.
|
||||
:param check: CHECK clause specify integrity constraint or test
|
||||
which values of the domain must satisfy. A constraint must be
|
||||
an expression producing a Boolean result that can use the key
|
||||
word VALUE to refer to the value being tested.
|
||||
Differently from PostgreSQL, only a single check clause is
|
||||
currently allowed in SQLAlchemy.
|
||||
:param schema: optional schema name
|
||||
:param metadata: optional :class:`_schema.MetaData` object which
|
||||
this :class:`_postgresql.DOMAIN` will be directly associated
|
||||
:param create_type: Defaults to True.
|
||||
Indicates that ``CREATE TYPE`` should be emitted, after optionally
|
||||
checking for the presence of the type, when the parent table is
|
||||
being created; and additionally that ``DROP TYPE`` is called
|
||||
when the table is dropped.
|
||||
|
||||
"""
|
||||
self.data_type = type_api.to_instance(data_type)
|
||||
self.default = default
|
||||
self.collation = collation
|
||||
self.constraint_name = constraint_name
|
||||
self.not_null = not_null
|
||||
if check is not None:
|
||||
check = coercions.expect(roles.DDLExpressionRole, check)
|
||||
self.check = check
|
||||
self.create_type = create_type
|
||||
super().__init__(name=name, **kw)
|
||||
|
||||
@classmethod
|
||||
def __test_init__(cls):
|
||||
return cls("name", sqltypes.Integer)
|
||||
|
||||
|
||||
class CreateEnumType(schema._CreateDropBase):
|
||||
__visit_name__ = "create_enum_type"
|
||||
|
||||
|
||||
class DropEnumType(schema._CreateDropBase):
|
||||
__visit_name__ = "drop_enum_type"
|
||||
|
||||
|
||||
class CreateDomainType(schema._CreateDropBase):
|
||||
"""Represent a CREATE DOMAIN statement."""
|
||||
|
||||
__visit_name__ = "create_domain_type"
|
||||
|
||||
|
||||
class DropDomainType(schema._CreateDropBase):
|
||||
"""Represent a DROP DOMAIN statement."""
|
||||
|
||||
__visit_name__ = "drop_domain_type"
|
|
@ -0,0 +1,129 @@
|
|||
# dialects/postgresql/operators.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
from ...sql import operators
|
||||
|
||||
|
||||
_getitem_precedence = operators._PRECEDENCE[operators.json_getitem_op]
|
||||
_eq_precedence = operators._PRECEDENCE[operators.eq]
|
||||
|
||||
# JSON + JSONB
|
||||
ASTEXT = operators.custom_op(
|
||||
"->>",
|
||||
precedence=_getitem_precedence,
|
||||
natural_self_precedent=True,
|
||||
eager_grouping=True,
|
||||
)
|
||||
|
||||
JSONPATH_ASTEXT = operators.custom_op(
|
||||
"#>>",
|
||||
precedence=_getitem_precedence,
|
||||
natural_self_precedent=True,
|
||||
eager_grouping=True,
|
||||
)
|
||||
|
||||
# JSONB + HSTORE
|
||||
HAS_KEY = operators.custom_op(
|
||||
"?",
|
||||
precedence=_eq_precedence,
|
||||
natural_self_precedent=True,
|
||||
eager_grouping=True,
|
||||
is_comparison=True,
|
||||
)
|
||||
|
||||
HAS_ALL = operators.custom_op(
|
||||
"?&",
|
||||
precedence=_eq_precedence,
|
||||
natural_self_precedent=True,
|
||||
eager_grouping=True,
|
||||
is_comparison=True,
|
||||
)
|
||||
|
||||
HAS_ANY = operators.custom_op(
|
||||
"?|",
|
||||
precedence=_eq_precedence,
|
||||
natural_self_precedent=True,
|
||||
eager_grouping=True,
|
||||
is_comparison=True,
|
||||
)
|
||||
|
||||
# JSONB
|
||||
DELETE_PATH = operators.custom_op(
|
||||
"#-",
|
||||
precedence=_getitem_precedence,
|
||||
natural_self_precedent=True,
|
||||
eager_grouping=True,
|
||||
)
|
||||
|
||||
PATH_EXISTS = operators.custom_op(
|
||||
"@?",
|
||||
precedence=_eq_precedence,
|
||||
natural_self_precedent=True,
|
||||
eager_grouping=True,
|
||||
is_comparison=True,
|
||||
)
|
||||
|
||||
PATH_MATCH = operators.custom_op(
|
||||
"@@",
|
||||
precedence=_eq_precedence,
|
||||
natural_self_precedent=True,
|
||||
eager_grouping=True,
|
||||
is_comparison=True,
|
||||
)
|
||||
|
||||
# JSONB + ARRAY + HSTORE + RANGE
|
||||
CONTAINS = operators.custom_op(
|
||||
"@>",
|
||||
precedence=_eq_precedence,
|
||||
natural_self_precedent=True,
|
||||
eager_grouping=True,
|
||||
is_comparison=True,
|
||||
)
|
||||
|
||||
CONTAINED_BY = operators.custom_op(
|
||||
"<@",
|
||||
precedence=_eq_precedence,
|
||||
natural_self_precedent=True,
|
||||
eager_grouping=True,
|
||||
is_comparison=True,
|
||||
)
|
||||
|
||||
# ARRAY + RANGE
|
||||
OVERLAP = operators.custom_op(
|
||||
"&&",
|
||||
precedence=_eq_precedence,
|
||||
is_comparison=True,
|
||||
)
|
||||
|
||||
# RANGE
|
||||
STRICTLY_LEFT_OF = operators.custom_op(
|
||||
"<<", precedence=_eq_precedence, is_comparison=True
|
||||
)
|
||||
|
||||
STRICTLY_RIGHT_OF = operators.custom_op(
|
||||
">>", precedence=_eq_precedence, is_comparison=True
|
||||
)
|
||||
|
||||
NOT_EXTEND_RIGHT_OF = operators.custom_op(
|
||||
"&<", precedence=_eq_precedence, is_comparison=True
|
||||
)
|
||||
|
||||
NOT_EXTEND_LEFT_OF = operators.custom_op(
|
||||
"&>", precedence=_eq_precedence, is_comparison=True
|
||||
)
|
||||
|
||||
ADJACENT_TO = operators.custom_op(
|
||||
"-|-", precedence=_eq_precedence, is_comparison=True
|
||||
)
|
||||
|
||||
# HSTORE
|
||||
GETITEM = operators.custom_op(
|
||||
"->",
|
||||
precedence=_getitem_precedence,
|
||||
natural_self_precedent=True,
|
||||
eager_grouping=True,
|
||||
)
|
|
@ -0,0 +1,662 @@
|
|||
# dialects/postgresql/pg8000.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors <see AUTHORS
|
||||
# file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
r"""
|
||||
.. dialect:: postgresql+pg8000
|
||||
:name: pg8000
|
||||
:dbapi: pg8000
|
||||
:connectstring: postgresql+pg8000://user:password@host:port/dbname[?key=value&key=value...]
|
||||
:url: https://pypi.org/project/pg8000/
|
||||
|
||||
.. versionchanged:: 1.4 The pg8000 dialect has been updated for version
|
||||
1.16.6 and higher, and is again part of SQLAlchemy's continuous integration
|
||||
with full feature support.
|
||||
|
||||
.. _pg8000_unicode:
|
||||
|
||||
Unicode
|
||||
-------
|
||||
|
||||
pg8000 will encode / decode string values between it and the server using the
|
||||
PostgreSQL ``client_encoding`` parameter; by default this is the value in
|
||||
the ``postgresql.conf`` file, which often defaults to ``SQL_ASCII``.
|
||||
Typically, this can be changed to ``utf-8``, as a more useful default::
|
||||
|
||||
#client_encoding = sql_ascii # actually, defaults to database
|
||||
# encoding
|
||||
client_encoding = utf8
|
||||
|
||||
The ``client_encoding`` can be overridden for a session by executing the SQL:
|
||||
|
||||
SET CLIENT_ENCODING TO 'utf8';
|
||||
|
||||
SQLAlchemy will execute this SQL on all new connections based on the value
|
||||
passed to :func:`_sa.create_engine` using the ``client_encoding`` parameter::
|
||||
|
||||
engine = create_engine(
|
||||
"postgresql+pg8000://user:pass@host/dbname", client_encoding='utf8')
|
||||
|
||||
.. _pg8000_ssl:
|
||||
|
||||
SSL Connections
|
||||
---------------
|
||||
|
||||
pg8000 accepts a Python ``SSLContext`` object which may be specified using the
|
||||
:paramref:`_sa.create_engine.connect_args` dictionary::
|
||||
|
||||
import ssl
|
||||
ssl_context = ssl.create_default_context()
|
||||
engine = sa.create_engine(
|
||||
"postgresql+pg8000://scott:tiger@192.168.0.199/test",
|
||||
connect_args={"ssl_context": ssl_context},
|
||||
)
|
||||
|
||||
If the server uses an automatically-generated certificate that is self-signed
|
||||
or does not match the host name (as seen from the client), it may also be
|
||||
necessary to disable hostname checking::
|
||||
|
||||
import ssl
|
||||
ssl_context = ssl.create_default_context()
|
||||
ssl_context.check_hostname = False
|
||||
ssl_context.verify_mode = ssl.CERT_NONE
|
||||
engine = sa.create_engine(
|
||||
"postgresql+pg8000://scott:tiger@192.168.0.199/test",
|
||||
connect_args={"ssl_context": ssl_context},
|
||||
)
|
||||
|
||||
.. _pg8000_isolation_level:
|
||||
|
||||
pg8000 Transaction Isolation Level
|
||||
-------------------------------------
|
||||
|
||||
The pg8000 dialect offers the same isolation level settings as that
|
||||
of the :ref:`psycopg2 <psycopg2_isolation_level>` dialect:
|
||||
|
||||
* ``READ COMMITTED``
|
||||
* ``READ UNCOMMITTED``
|
||||
* ``REPEATABLE READ``
|
||||
* ``SERIALIZABLE``
|
||||
* ``AUTOCOMMIT``
|
||||
|
||||
.. seealso::
|
||||
|
||||
:ref:`postgresql_isolation_level`
|
||||
|
||||
:ref:`psycopg2_isolation_level`
|
||||
|
||||
|
||||
""" # noqa
|
||||
import decimal
|
||||
import re
|
||||
|
||||
from . import ranges
|
||||
from .array import ARRAY as PGARRAY
|
||||
from .base import _DECIMAL_TYPES
|
||||
from .base import _FLOAT_TYPES
|
||||
from .base import _INT_TYPES
|
||||
from .base import ENUM
|
||||
from .base import INTERVAL
|
||||
from .base import PGCompiler
|
||||
from .base import PGDialect
|
||||
from .base import PGExecutionContext
|
||||
from .base import PGIdentifierPreparer
|
||||
from .json import JSON
|
||||
from .json import JSONB
|
||||
from .json import JSONPathType
|
||||
from .pg_catalog import _SpaceVector
|
||||
from .pg_catalog import OIDVECTOR
|
||||
from .types import CITEXT
|
||||
from ... import exc
|
||||
from ... import util
|
||||
from ...engine import processors
|
||||
from ...sql import sqltypes
|
||||
from ...sql.elements import quoted_name
|
||||
|
||||
|
||||
class _PGString(sqltypes.String):
|
||||
render_bind_cast = True
|
||||
|
||||
|
||||
class _PGNumeric(sqltypes.Numeric):
|
||||
render_bind_cast = True
|
||||
|
||||
def result_processor(self, dialect, coltype):
|
||||
if self.asdecimal:
|
||||
if coltype in _FLOAT_TYPES:
|
||||
return processors.to_decimal_processor_factory(
|
||||
decimal.Decimal, self._effective_decimal_return_scale
|
||||
)
|
||||
elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
|
||||
# pg8000 returns Decimal natively for 1700
|
||||
return None
|
||||
else:
|
||||
raise exc.InvalidRequestError(
|
||||
"Unknown PG numeric type: %d" % coltype
|
||||
)
|
||||
else:
|
||||
if coltype in _FLOAT_TYPES:
|
||||
# pg8000 returns float natively for 701
|
||||
return None
|
||||
elif coltype in _DECIMAL_TYPES or coltype in _INT_TYPES:
|
||||
return processors.to_float
|
||||
else:
|
||||
raise exc.InvalidRequestError(
|
||||
"Unknown PG numeric type: %d" % coltype
|
||||
)
|
||||
|
||||
|
||||
class _PGFloat(_PGNumeric, sqltypes.Float):
|
||||
__visit_name__ = "float"
|
||||
render_bind_cast = True
|
||||
|
||||
|
||||
class _PGNumericNoBind(_PGNumeric):
|
||||
def bind_processor(self, dialect):
|
||||
return None
|
||||
|
||||
|
||||
class _PGJSON(JSON):
|
||||
render_bind_cast = True
|
||||
|
||||
def result_processor(self, dialect, coltype):
|
||||
return None
|
||||
|
||||
|
||||
class _PGJSONB(JSONB):
|
||||
render_bind_cast = True
|
||||
|
||||
def result_processor(self, dialect, coltype):
|
||||
return None
|
||||
|
||||
|
||||
class _PGJSONIndexType(sqltypes.JSON.JSONIndexType):
|
||||
def get_dbapi_type(self, dbapi):
|
||||
raise NotImplementedError("should not be here")
|
||||
|
||||
|
||||
class _PGJSONIntIndexType(sqltypes.JSON.JSONIntIndexType):
|
||||
__visit_name__ = "json_int_index"
|
||||
|
||||
render_bind_cast = True
|
||||
|
||||
|
||||
class _PGJSONStrIndexType(sqltypes.JSON.JSONStrIndexType):
|
||||
__visit_name__ = "json_str_index"
|
||||
|
||||
render_bind_cast = True
|
||||
|
||||
|
||||
class _PGJSONPathType(JSONPathType):
|
||||
pass
|
||||
|
||||
# DBAPI type 1009
|
||||
|
||||
|
||||
class _PGEnum(ENUM):
|
||||
def get_dbapi_type(self, dbapi):
|
||||
return dbapi.UNKNOWN
|
||||
|
||||
|
||||
class _PGInterval(INTERVAL):
|
||||
render_bind_cast = True
|
||||
|
||||
def get_dbapi_type(self, dbapi):
|
||||
return dbapi.INTERVAL
|
||||
|
||||
@classmethod
|
||||
def adapt_emulated_to_native(cls, interval, **kw):
|
||||
return _PGInterval(precision=interval.second_precision)
|
||||
|
||||
|
||||
class _PGTimeStamp(sqltypes.DateTime):
|
||||
render_bind_cast = True
|
||||
|
||||
|
||||
class _PGDate(sqltypes.Date):
|
||||
render_bind_cast = True
|
||||
|
||||
|
||||
class _PGTime(sqltypes.Time):
|
||||
render_bind_cast = True
|
||||
|
||||
|
||||
class _PGInteger(sqltypes.Integer):
|
||||
render_bind_cast = True
|
||||
|
||||
|
||||
class _PGSmallInteger(sqltypes.SmallInteger):
|
||||
render_bind_cast = True
|
||||
|
||||
|
||||
class _PGNullType(sqltypes.NullType):
|
||||
pass
|
||||
|
||||
|
||||
class _PGBigInteger(sqltypes.BigInteger):
|
||||
render_bind_cast = True
|
||||
|
||||
|
||||
class _PGBoolean(sqltypes.Boolean):
|
||||
render_bind_cast = True
|
||||
|
||||
|
||||
class _PGARRAY(PGARRAY):
|
||||
render_bind_cast = True
|
||||
|
||||
|
||||
class _PGOIDVECTOR(_SpaceVector, OIDVECTOR):
|
||||
pass
|
||||
|
||||
|
||||
class _Pg8000Range(ranges.AbstractSingleRangeImpl):
|
||||
def bind_processor(self, dialect):
|
||||
pg8000_Range = dialect.dbapi.Range
|
||||
|
||||
def to_range(value):
|
||||
if isinstance(value, ranges.Range):
|
||||
value = pg8000_Range(
|
||||
value.lower, value.upper, value.bounds, value.empty
|
||||
)
|
||||
return value
|
||||
|
||||
return to_range
|
||||
|
||||
def result_processor(self, dialect, coltype):
|
||||
def to_range(value):
|
||||
if value is not None:
|
||||
value = ranges.Range(
|
||||
value.lower,
|
||||
value.upper,
|
||||
bounds=value.bounds,
|
||||
empty=value.is_empty,
|
||||
)
|
||||
return value
|
||||
|
||||
return to_range
|
||||
|
||||
|
||||
class _Pg8000MultiRange(ranges.AbstractMultiRangeImpl):
|
||||
def bind_processor(self, dialect):
|
||||
pg8000_Range = dialect.dbapi.Range
|
||||
|
||||
def to_multirange(value):
|
||||
if isinstance(value, list):
|
||||
mr = []
|
||||
for v in value:
|
||||
if isinstance(v, ranges.Range):
|
||||
mr.append(
|
||||
pg8000_Range(v.lower, v.upper, v.bounds, v.empty)
|
||||
)
|
||||
else:
|
||||
mr.append(v)
|
||||
return mr
|
||||
else:
|
||||
return value
|
||||
|
||||
return to_multirange
|
||||
|
||||
def result_processor(self, dialect, coltype):
|
||||
def to_multirange(value):
|
||||
if value is None:
|
||||
return None
|
||||
else:
|
||||
return ranges.MultiRange(
|
||||
ranges.Range(
|
||||
v.lower, v.upper, bounds=v.bounds, empty=v.is_empty
|
||||
)
|
||||
for v in value
|
||||
)
|
||||
|
||||
return to_multirange
|
||||
|
||||
|
||||
_server_side_id = util.counter()
|
||||
|
||||
|
||||
class PGExecutionContext_pg8000(PGExecutionContext):
|
||||
def create_server_side_cursor(self):
|
||||
ident = "c_%s_%s" % (hex(id(self))[2:], hex(_server_side_id())[2:])
|
||||
return ServerSideCursor(self._dbapi_connection.cursor(), ident)
|
||||
|
||||
def pre_exec(self):
|
||||
if not self.compiled:
|
||||
return
|
||||
|
||||
|
||||
class ServerSideCursor:
|
||||
server_side = True
|
||||
|
||||
def __init__(self, cursor, ident):
|
||||
self.ident = ident
|
||||
self.cursor = cursor
|
||||
|
||||
@property
|
||||
def connection(self):
|
||||
return self.cursor.connection
|
||||
|
||||
@property
|
||||
def rowcount(self):
|
||||
return self.cursor.rowcount
|
||||
|
||||
@property
|
||||
def description(self):
|
||||
return self.cursor.description
|
||||
|
||||
def execute(self, operation, args=(), stream=None):
|
||||
op = "DECLARE " + self.ident + " NO SCROLL CURSOR FOR " + operation
|
||||
self.cursor.execute(op, args, stream=stream)
|
||||
return self
|
||||
|
||||
def executemany(self, operation, param_sets):
|
||||
self.cursor.executemany(operation, param_sets)
|
||||
return self
|
||||
|
||||
def fetchone(self):
|
||||
self.cursor.execute("FETCH FORWARD 1 FROM " + self.ident)
|
||||
return self.cursor.fetchone()
|
||||
|
||||
def fetchmany(self, num=None):
|
||||
if num is None:
|
||||
return self.fetchall()
|
||||
else:
|
||||
self.cursor.execute(
|
||||
"FETCH FORWARD " + str(int(num)) + " FROM " + self.ident
|
||||
)
|
||||
return self.cursor.fetchall()
|
||||
|
||||
def fetchall(self):
|
||||
self.cursor.execute("FETCH FORWARD ALL FROM " + self.ident)
|
||||
return self.cursor.fetchall()
|
||||
|
||||
def close(self):
|
||||
self.cursor.execute("CLOSE " + self.ident)
|
||||
self.cursor.close()
|
||||
|
||||
def setinputsizes(self, *sizes):
|
||||
self.cursor.setinputsizes(*sizes)
|
||||
|
||||
def setoutputsize(self, size, column=None):
|
||||
pass
|
||||
|
||||
|
||||
class PGCompiler_pg8000(PGCompiler):
|
||||
def visit_mod_binary(self, binary, operator, **kw):
|
||||
return (
|
||||
self.process(binary.left, **kw)
|
||||
+ " %% "
|
||||
+ self.process(binary.right, **kw)
|
||||
)
|
||||
|
||||
|
||||
class PGIdentifierPreparer_pg8000(PGIdentifierPreparer):
|
||||
def __init__(self, *args, **kwargs):
|
||||
PGIdentifierPreparer.__init__(self, *args, **kwargs)
|
||||
self._double_percents = False
|
||||
|
||||
|
||||
class PGDialect_pg8000(PGDialect):
|
||||
driver = "pg8000"
|
||||
supports_statement_cache = True
|
||||
|
||||
supports_unicode_statements = True
|
||||
|
||||
supports_unicode_binds = True
|
||||
|
||||
default_paramstyle = "format"
|
||||
supports_sane_multi_rowcount = True
|
||||
execution_ctx_cls = PGExecutionContext_pg8000
|
||||
statement_compiler = PGCompiler_pg8000
|
||||
preparer = PGIdentifierPreparer_pg8000
|
||||
supports_server_side_cursors = True
|
||||
|
||||
render_bind_cast = True
|
||||
|
||||
# reversed as of pg8000 1.16.6. 1.16.5 and lower
|
||||
# are no longer compatible
|
||||
description_encoding = None
|
||||
# description_encoding = "use_encoding"
|
||||
|
||||
colspecs = util.update_copy(
|
||||
PGDialect.colspecs,
|
||||
{
|
||||
sqltypes.String: _PGString,
|
||||
sqltypes.Numeric: _PGNumericNoBind,
|
||||
sqltypes.Float: _PGFloat,
|
||||
sqltypes.JSON: _PGJSON,
|
||||
sqltypes.Boolean: _PGBoolean,
|
||||
sqltypes.NullType: _PGNullType,
|
||||
JSONB: _PGJSONB,
|
||||
CITEXT: CITEXT,
|
||||
sqltypes.JSON.JSONPathType: _PGJSONPathType,
|
||||
sqltypes.JSON.JSONIndexType: _PGJSONIndexType,
|
||||
sqltypes.JSON.JSONIntIndexType: _PGJSONIntIndexType,
|
||||
sqltypes.JSON.JSONStrIndexType: _PGJSONStrIndexType,
|
||||
sqltypes.Interval: _PGInterval,
|
||||
INTERVAL: _PGInterval,
|
||||
sqltypes.DateTime: _PGTimeStamp,
|
||||
sqltypes.DateTime: _PGTimeStamp,
|
||||
sqltypes.Date: _PGDate,
|
||||
sqltypes.Time: _PGTime,
|
||||
sqltypes.Integer: _PGInteger,
|
||||
sqltypes.SmallInteger: _PGSmallInteger,
|
||||
sqltypes.BigInteger: _PGBigInteger,
|
||||
sqltypes.Enum: _PGEnum,
|
||||
sqltypes.ARRAY: _PGARRAY,
|
||||
OIDVECTOR: _PGOIDVECTOR,
|
||||
ranges.INT4RANGE: _Pg8000Range,
|
||||
ranges.INT8RANGE: _Pg8000Range,
|
||||
ranges.NUMRANGE: _Pg8000Range,
|
||||
ranges.DATERANGE: _Pg8000Range,
|
||||
ranges.TSRANGE: _Pg8000Range,
|
||||
ranges.TSTZRANGE: _Pg8000Range,
|
||||
ranges.INT4MULTIRANGE: _Pg8000MultiRange,
|
||||
ranges.INT8MULTIRANGE: _Pg8000MultiRange,
|
||||
ranges.NUMMULTIRANGE: _Pg8000MultiRange,
|
||||
ranges.DATEMULTIRANGE: _Pg8000MultiRange,
|
||||
ranges.TSMULTIRANGE: _Pg8000MultiRange,
|
||||
ranges.TSTZMULTIRANGE: _Pg8000MultiRange,
|
||||
},
|
||||
)
|
||||
|
||||
def __init__(self, client_encoding=None, **kwargs):
|
||||
PGDialect.__init__(self, **kwargs)
|
||||
self.client_encoding = client_encoding
|
||||
|
||||
if self._dbapi_version < (1, 16, 6):
|
||||
raise NotImplementedError("pg8000 1.16.6 or greater is required")
|
||||
|
||||
if self._native_inet_types:
|
||||
raise NotImplementedError(
|
||||
"The pg8000 dialect does not fully implement "
|
||||
"ipaddress type handling; INET is supported by default, "
|
||||
"CIDR is not"
|
||||
)
|
||||
|
||||
@util.memoized_property
|
||||
def _dbapi_version(self):
|
||||
if self.dbapi and hasattr(self.dbapi, "__version__"):
|
||||
return tuple(
|
||||
[
|
||||
int(x)
|
||||
for x in re.findall(
|
||||
r"(\d+)(?:[-\.]?|$)", self.dbapi.__version__
|
||||
)
|
||||
]
|
||||
)
|
||||
else:
|
||||
return (99, 99, 99)
|
||||
|
||||
@classmethod
|
||||
def import_dbapi(cls):
|
||||
return __import__("pg8000")
|
||||
|
||||
def create_connect_args(self, url):
|
||||
opts = url.translate_connect_args(username="user")
|
||||
if "port" in opts:
|
||||
opts["port"] = int(opts["port"])
|
||||
opts.update(url.query)
|
||||
return ([], opts)
|
||||
|
||||
def is_disconnect(self, e, connection, cursor):
|
||||
if isinstance(e, self.dbapi.InterfaceError) and "network error" in str(
|
||||
e
|
||||
):
|
||||
# new as of pg8000 1.19.0 for broken connections
|
||||
return True
|
||||
|
||||
# connection was closed normally
|
||||
return "connection is closed" in str(e)
|
||||
|
||||
def get_isolation_level_values(self, dbapi_connection):
|
||||
return (
|
||||
"AUTOCOMMIT",
|
||||
"READ COMMITTED",
|
||||
"READ UNCOMMITTED",
|
||||
"REPEATABLE READ",
|
||||
"SERIALIZABLE",
|
||||
)
|
||||
|
||||
def set_isolation_level(self, dbapi_connection, level):
|
||||
level = level.replace("_", " ")
|
||||
|
||||
if level == "AUTOCOMMIT":
|
||||
dbapi_connection.autocommit = True
|
||||
else:
|
||||
dbapi_connection.autocommit = False
|
||||
cursor = dbapi_connection.cursor()
|
||||
cursor.execute(
|
||||
"SET SESSION CHARACTERISTICS AS TRANSACTION "
|
||||
f"ISOLATION LEVEL {level}"
|
||||
)
|
||||
cursor.execute("COMMIT")
|
||||
cursor.close()
|
||||
|
||||
def set_readonly(self, connection, value):
|
||||
cursor = connection.cursor()
|
||||
try:
|
||||
cursor.execute(
|
||||
"SET SESSION CHARACTERISTICS AS TRANSACTION %s"
|
||||
% ("READ ONLY" if value else "READ WRITE")
|
||||
)
|
||||
cursor.execute("COMMIT")
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
def get_readonly(self, connection):
|
||||
cursor = connection.cursor()
|
||||
try:
|
||||
cursor.execute("show transaction_read_only")
|
||||
val = cursor.fetchone()[0]
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
return val == "on"
|
||||
|
||||
def set_deferrable(self, connection, value):
|
||||
cursor = connection.cursor()
|
||||
try:
|
||||
cursor.execute(
|
||||
"SET SESSION CHARACTERISTICS AS TRANSACTION %s"
|
||||
% ("DEFERRABLE" if value else "NOT DEFERRABLE")
|
||||
)
|
||||
cursor.execute("COMMIT")
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
def get_deferrable(self, connection):
|
||||
cursor = connection.cursor()
|
||||
try:
|
||||
cursor.execute("show transaction_deferrable")
|
||||
val = cursor.fetchone()[0]
|
||||
finally:
|
||||
cursor.close()
|
||||
|
||||
return val == "on"
|
||||
|
||||
def _set_client_encoding(self, dbapi_connection, client_encoding):
|
||||
cursor = dbapi_connection.cursor()
|
||||
cursor.execute(
|
||||
f"""SET CLIENT_ENCODING TO '{
|
||||
client_encoding.replace("'", "''")
|
||||
}'"""
|
||||
)
|
||||
cursor.execute("COMMIT")
|
||||
cursor.close()
|
||||
|
||||
def do_begin_twophase(self, connection, xid):
|
||||
connection.connection.tpc_begin((0, xid, ""))
|
||||
|
||||
def do_prepare_twophase(self, connection, xid):
|
||||
connection.connection.tpc_prepare()
|
||||
|
||||
def do_rollback_twophase(
|
||||
self, connection, xid, is_prepared=True, recover=False
|
||||
):
|
||||
connection.connection.tpc_rollback((0, xid, ""))
|
||||
|
||||
def do_commit_twophase(
|
||||
self, connection, xid, is_prepared=True, recover=False
|
||||
):
|
||||
connection.connection.tpc_commit((0, xid, ""))
|
||||
|
||||
def do_recover_twophase(self, connection):
|
||||
return [row[1] for row in connection.connection.tpc_recover()]
|
||||
|
||||
def on_connect(self):
|
||||
fns = []
|
||||
|
||||
def on_connect(conn):
|
||||
conn.py_types[quoted_name] = conn.py_types[str]
|
||||
|
||||
fns.append(on_connect)
|
||||
|
||||
if self.client_encoding is not None:
|
||||
|
||||
def on_connect(conn):
|
||||
self._set_client_encoding(conn, self.client_encoding)
|
||||
|
||||
fns.append(on_connect)
|
||||
|
||||
if self._native_inet_types is False:
|
||||
|
||||
def on_connect(conn):
|
||||
# inet
|
||||
conn.register_in_adapter(869, lambda s: s)
|
||||
|
||||
# cidr
|
||||
conn.register_in_adapter(650, lambda s: s)
|
||||
|
||||
fns.append(on_connect)
|
||||
|
||||
if self._json_deserializer:
|
||||
|
||||
def on_connect(conn):
|
||||
# json
|
||||
conn.register_in_adapter(114, self._json_deserializer)
|
||||
|
||||
# jsonb
|
||||
conn.register_in_adapter(3802, self._json_deserializer)
|
||||
|
||||
fns.append(on_connect)
|
||||
|
||||
if len(fns) > 0:
|
||||
|
||||
def on_connect(conn):
|
||||
for fn in fns:
|
||||
fn(conn)
|
||||
|
||||
return on_connect
|
||||
else:
|
||||
return None
|
||||
|
||||
@util.memoized_property
|
||||
def _dialect_specific_select_one(self):
|
||||
return ";"
|
||||
|
||||
|
||||
dialect = PGDialect_pg8000
|
|
@ -0,0 +1,294 @@
|
|||
# dialects/postgresql/pg_catalog.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
from .array import ARRAY
|
||||
from .types import OID
|
||||
from .types import REGCLASS
|
||||
from ... import Column
|
||||
from ... import func
|
||||
from ... import MetaData
|
||||
from ... import Table
|
||||
from ...types import BigInteger
|
||||
from ...types import Boolean
|
||||
from ...types import CHAR
|
||||
from ...types import Float
|
||||
from ...types import Integer
|
||||
from ...types import SmallInteger
|
||||
from ...types import String
|
||||
from ...types import Text
|
||||
from ...types import TypeDecorator
|
||||
|
||||
|
||||
# types
|
||||
class NAME(TypeDecorator):
|
||||
impl = String(64, collation="C")
|
||||
cache_ok = True
|
||||
|
||||
|
||||
class PG_NODE_TREE(TypeDecorator):
|
||||
impl = Text(collation="C")
|
||||
cache_ok = True
|
||||
|
||||
|
||||
class INT2VECTOR(TypeDecorator):
|
||||
impl = ARRAY(SmallInteger)
|
||||
cache_ok = True
|
||||
|
||||
|
||||
class OIDVECTOR(TypeDecorator):
|
||||
impl = ARRAY(OID)
|
||||
cache_ok = True
|
||||
|
||||
|
||||
class _SpaceVector:
|
||||
def result_processor(self, dialect, coltype):
|
||||
def process(value):
|
||||
if value is None:
|
||||
return value
|
||||
return [int(p) for p in value.split(" ")]
|
||||
|
||||
return process
|
||||
|
||||
|
||||
REGPROC = REGCLASS # seems an alias
|
||||
|
||||
# functions
|
||||
_pg_cat = func.pg_catalog
|
||||
quote_ident = _pg_cat.quote_ident
|
||||
pg_table_is_visible = _pg_cat.pg_table_is_visible
|
||||
pg_type_is_visible = _pg_cat.pg_type_is_visible
|
||||
pg_get_viewdef = _pg_cat.pg_get_viewdef
|
||||
pg_get_serial_sequence = _pg_cat.pg_get_serial_sequence
|
||||
format_type = _pg_cat.format_type
|
||||
pg_get_expr = _pg_cat.pg_get_expr
|
||||
pg_get_constraintdef = _pg_cat.pg_get_constraintdef
|
||||
pg_get_indexdef = _pg_cat.pg_get_indexdef
|
||||
|
||||
# constants
|
||||
RELKINDS_TABLE_NO_FOREIGN = ("r", "p")
|
||||
RELKINDS_TABLE = RELKINDS_TABLE_NO_FOREIGN + ("f",)
|
||||
RELKINDS_VIEW = ("v",)
|
||||
RELKINDS_MAT_VIEW = ("m",)
|
||||
RELKINDS_ALL_TABLE_LIKE = RELKINDS_TABLE + RELKINDS_VIEW + RELKINDS_MAT_VIEW
|
||||
|
||||
# tables
|
||||
pg_catalog_meta = MetaData()
|
||||
|
||||
pg_namespace = Table(
|
||||
"pg_namespace",
|
||||
pg_catalog_meta,
|
||||
Column("oid", OID),
|
||||
Column("nspname", NAME),
|
||||
Column("nspowner", OID),
|
||||
schema="pg_catalog",
|
||||
)
|
||||
|
||||
pg_class = Table(
|
||||
"pg_class",
|
||||
pg_catalog_meta,
|
||||
Column("oid", OID, info={"server_version": (9, 3)}),
|
||||
Column("relname", NAME),
|
||||
Column("relnamespace", OID),
|
||||
Column("reltype", OID),
|
||||
Column("reloftype", OID),
|
||||
Column("relowner", OID),
|
||||
Column("relam", OID),
|
||||
Column("relfilenode", OID),
|
||||
Column("reltablespace", OID),
|
||||
Column("relpages", Integer),
|
||||
Column("reltuples", Float),
|
||||
Column("relallvisible", Integer, info={"server_version": (9, 2)}),
|
||||
Column("reltoastrelid", OID),
|
||||
Column("relhasindex", Boolean),
|
||||
Column("relisshared", Boolean),
|
||||
Column("relpersistence", CHAR, info={"server_version": (9, 1)}),
|
||||
Column("relkind", CHAR),
|
||||
Column("relnatts", SmallInteger),
|
||||
Column("relchecks", SmallInteger),
|
||||
Column("relhasrules", Boolean),
|
||||
Column("relhastriggers", Boolean),
|
||||
Column("relhassubclass", Boolean),
|
||||
Column("relrowsecurity", Boolean),
|
||||
Column("relforcerowsecurity", Boolean, info={"server_version": (9, 5)}),
|
||||
Column("relispopulated", Boolean, info={"server_version": (9, 3)}),
|
||||
Column("relreplident", CHAR, info={"server_version": (9, 4)}),
|
||||
Column("relispartition", Boolean, info={"server_version": (10,)}),
|
||||
Column("relrewrite", OID, info={"server_version": (11,)}),
|
||||
Column("reloptions", ARRAY(Text)),
|
||||
schema="pg_catalog",
|
||||
)
|
||||
|
||||
pg_type = Table(
|
||||
"pg_type",
|
||||
pg_catalog_meta,
|
||||
Column("oid", OID, info={"server_version": (9, 3)}),
|
||||
Column("typname", NAME),
|
||||
Column("typnamespace", OID),
|
||||
Column("typowner", OID),
|
||||
Column("typlen", SmallInteger),
|
||||
Column("typbyval", Boolean),
|
||||
Column("typtype", CHAR),
|
||||
Column("typcategory", CHAR),
|
||||
Column("typispreferred", Boolean),
|
||||
Column("typisdefined", Boolean),
|
||||
Column("typdelim", CHAR),
|
||||
Column("typrelid", OID),
|
||||
Column("typelem", OID),
|
||||
Column("typarray", OID),
|
||||
Column("typinput", REGPROC),
|
||||
Column("typoutput", REGPROC),
|
||||
Column("typreceive", REGPROC),
|
||||
Column("typsend", REGPROC),
|
||||
Column("typmodin", REGPROC),
|
||||
Column("typmodout", REGPROC),
|
||||
Column("typanalyze", REGPROC),
|
||||
Column("typalign", CHAR),
|
||||
Column("typstorage", CHAR),
|
||||
Column("typnotnull", Boolean),
|
||||
Column("typbasetype", OID),
|
||||
Column("typtypmod", Integer),
|
||||
Column("typndims", Integer),
|
||||
Column("typcollation", OID, info={"server_version": (9, 1)}),
|
||||
Column("typdefault", Text),
|
||||
schema="pg_catalog",
|
||||
)
|
||||
|
||||
pg_index = Table(
|
||||
"pg_index",
|
||||
pg_catalog_meta,
|
||||
Column("indexrelid", OID),
|
||||
Column("indrelid", OID),
|
||||
Column("indnatts", SmallInteger),
|
||||
Column("indnkeyatts", SmallInteger, info={"server_version": (11,)}),
|
||||
Column("indisunique", Boolean),
|
||||
Column("indnullsnotdistinct", Boolean, info={"server_version": (15,)}),
|
||||
Column("indisprimary", Boolean),
|
||||
Column("indisexclusion", Boolean, info={"server_version": (9, 1)}),
|
||||
Column("indimmediate", Boolean),
|
||||
Column("indisclustered", Boolean),
|
||||
Column("indisvalid", Boolean),
|
||||
Column("indcheckxmin", Boolean),
|
||||
Column("indisready", Boolean),
|
||||
Column("indislive", Boolean, info={"server_version": (9, 3)}), # 9.3
|
||||
Column("indisreplident", Boolean),
|
||||
Column("indkey", INT2VECTOR),
|
||||
Column("indcollation", OIDVECTOR, info={"server_version": (9, 1)}), # 9.1
|
||||
Column("indclass", OIDVECTOR),
|
||||
Column("indoption", INT2VECTOR),
|
||||
Column("indexprs", PG_NODE_TREE),
|
||||
Column("indpred", PG_NODE_TREE),
|
||||
schema="pg_catalog",
|
||||
)
|
||||
|
||||
pg_attribute = Table(
|
||||
"pg_attribute",
|
||||
pg_catalog_meta,
|
||||
Column("attrelid", OID),
|
||||
Column("attname", NAME),
|
||||
Column("atttypid", OID),
|
||||
Column("attstattarget", Integer),
|
||||
Column("attlen", SmallInteger),
|
||||
Column("attnum", SmallInteger),
|
||||
Column("attndims", Integer),
|
||||
Column("attcacheoff", Integer),
|
||||
Column("atttypmod", Integer),
|
||||
Column("attbyval", Boolean),
|
||||
Column("attstorage", CHAR),
|
||||
Column("attalign", CHAR),
|
||||
Column("attnotnull", Boolean),
|
||||
Column("atthasdef", Boolean),
|
||||
Column("atthasmissing", Boolean, info={"server_version": (11,)}),
|
||||
Column("attidentity", CHAR, info={"server_version": (10,)}),
|
||||
Column("attgenerated", CHAR, info={"server_version": (12,)}),
|
||||
Column("attisdropped", Boolean),
|
||||
Column("attislocal", Boolean),
|
||||
Column("attinhcount", Integer),
|
||||
Column("attcollation", OID, info={"server_version": (9, 1)}),
|
||||
schema="pg_catalog",
|
||||
)
|
||||
|
||||
pg_constraint = Table(
|
||||
"pg_constraint",
|
||||
pg_catalog_meta,
|
||||
Column("oid", OID), # 9.3
|
||||
Column("conname", NAME),
|
||||
Column("connamespace", OID),
|
||||
Column("contype", CHAR),
|
||||
Column("condeferrable", Boolean),
|
||||
Column("condeferred", Boolean),
|
||||
Column("convalidated", Boolean, info={"server_version": (9, 1)}),
|
||||
Column("conrelid", OID),
|
||||
Column("contypid", OID),
|
||||
Column("conindid", OID),
|
||||
Column("conparentid", OID, info={"server_version": (11,)}),
|
||||
Column("confrelid", OID),
|
||||
Column("confupdtype", CHAR),
|
||||
Column("confdeltype", CHAR),
|
||||
Column("confmatchtype", CHAR),
|
||||
Column("conislocal", Boolean),
|
||||
Column("coninhcount", Integer),
|
||||
Column("connoinherit", Boolean, info={"server_version": (9, 2)}),
|
||||
Column("conkey", ARRAY(SmallInteger)),
|
||||
Column("confkey", ARRAY(SmallInteger)),
|
||||
schema="pg_catalog",
|
||||
)
|
||||
|
||||
pg_sequence = Table(
|
||||
"pg_sequence",
|
||||
pg_catalog_meta,
|
||||
Column("seqrelid", OID),
|
||||
Column("seqtypid", OID),
|
||||
Column("seqstart", BigInteger),
|
||||
Column("seqincrement", BigInteger),
|
||||
Column("seqmax", BigInteger),
|
||||
Column("seqmin", BigInteger),
|
||||
Column("seqcache", BigInteger),
|
||||
Column("seqcycle", Boolean),
|
||||
schema="pg_catalog",
|
||||
info={"server_version": (10,)},
|
||||
)
|
||||
|
||||
pg_attrdef = Table(
|
||||
"pg_attrdef",
|
||||
pg_catalog_meta,
|
||||
Column("oid", OID, info={"server_version": (9, 3)}),
|
||||
Column("adrelid", OID),
|
||||
Column("adnum", SmallInteger),
|
||||
Column("adbin", PG_NODE_TREE),
|
||||
schema="pg_catalog",
|
||||
)
|
||||
|
||||
pg_description = Table(
|
||||
"pg_description",
|
||||
pg_catalog_meta,
|
||||
Column("objoid", OID),
|
||||
Column("classoid", OID),
|
||||
Column("objsubid", Integer),
|
||||
Column("description", Text(collation="C")),
|
||||
schema="pg_catalog",
|
||||
)
|
||||
|
||||
pg_enum = Table(
|
||||
"pg_enum",
|
||||
pg_catalog_meta,
|
||||
Column("oid", OID, info={"server_version": (9, 3)}),
|
||||
Column("enumtypid", OID),
|
||||
Column("enumsortorder", Float(), info={"server_version": (9, 1)}),
|
||||
Column("enumlabel", NAME),
|
||||
schema="pg_catalog",
|
||||
)
|
||||
|
||||
pg_am = Table(
|
||||
"pg_am",
|
||||
pg_catalog_meta,
|
||||
Column("oid", OID, info={"server_version": (9, 3)}),
|
||||
Column("amname", NAME),
|
||||
Column("amhandler", REGPROC, info={"server_version": (9, 6)}),
|
||||
Column("amtype", CHAR, info={"server_version": (9, 6)}),
|
||||
schema="pg_catalog",
|
||||
)
|
|
@ -0,0 +1,175 @@
|
|||
# dialects/postgresql/provision.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
import time
|
||||
|
||||
from ... import exc
|
||||
from ... import inspect
|
||||
from ... import text
|
||||
from ...testing import warn_test_suite
|
||||
from ...testing.provision import create_db
|
||||
from ...testing.provision import drop_all_schema_objects_post_tables
|
||||
from ...testing.provision import drop_all_schema_objects_pre_tables
|
||||
from ...testing.provision import drop_db
|
||||
from ...testing.provision import log
|
||||
from ...testing.provision import post_configure_engine
|
||||
from ...testing.provision import prepare_for_drop_tables
|
||||
from ...testing.provision import set_default_schema_on_connection
|
||||
from ...testing.provision import temp_table_keyword_args
|
||||
from ...testing.provision import upsert
|
||||
|
||||
|
||||
@create_db.for_db("postgresql")
|
||||
def _pg_create_db(cfg, eng, ident):
|
||||
template_db = cfg.options.postgresql_templatedb
|
||||
|
||||
with eng.execution_options(isolation_level="AUTOCOMMIT").begin() as conn:
|
||||
if not template_db:
|
||||
template_db = conn.exec_driver_sql(
|
||||
"select current_database()"
|
||||
).scalar()
|
||||
|
||||
attempt = 0
|
||||
while True:
|
||||
try:
|
||||
conn.exec_driver_sql(
|
||||
"CREATE DATABASE %s TEMPLATE %s" % (ident, template_db)
|
||||
)
|
||||
except exc.OperationalError as err:
|
||||
attempt += 1
|
||||
if attempt >= 3:
|
||||
raise
|
||||
if "accessed by other users" in str(err):
|
||||
log.info(
|
||||
"Waiting to create %s, URI %r, "
|
||||
"template DB %s is in use sleeping for .5",
|
||||
ident,
|
||||
eng.url,
|
||||
template_db,
|
||||
)
|
||||
time.sleep(0.5)
|
||||
except:
|
||||
raise
|
||||
else:
|
||||
break
|
||||
|
||||
|
||||
@drop_db.for_db("postgresql")
|
||||
def _pg_drop_db(cfg, eng, ident):
|
||||
with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
|
||||
with conn.begin():
|
||||
conn.execute(
|
||||
text(
|
||||
"select pg_terminate_backend(pid) from pg_stat_activity "
|
||||
"where usename=current_user and pid != pg_backend_pid() "
|
||||
"and datname=:dname"
|
||||
),
|
||||
dict(dname=ident),
|
||||
)
|
||||
conn.exec_driver_sql("DROP DATABASE %s" % ident)
|
||||
|
||||
|
||||
@temp_table_keyword_args.for_db("postgresql")
|
||||
def _postgresql_temp_table_keyword_args(cfg, eng):
|
||||
return {"prefixes": ["TEMPORARY"]}
|
||||
|
||||
|
||||
@set_default_schema_on_connection.for_db("postgresql")
|
||||
def _postgresql_set_default_schema_on_connection(
|
||||
cfg, dbapi_connection, schema_name
|
||||
):
|
||||
existing_autocommit = dbapi_connection.autocommit
|
||||
dbapi_connection.autocommit = True
|
||||
cursor = dbapi_connection.cursor()
|
||||
cursor.execute("SET SESSION search_path='%s'" % schema_name)
|
||||
cursor.close()
|
||||
dbapi_connection.autocommit = existing_autocommit
|
||||
|
||||
|
||||
@drop_all_schema_objects_pre_tables.for_db("postgresql")
|
||||
def drop_all_schema_objects_pre_tables(cfg, eng):
|
||||
with eng.connect().execution_options(isolation_level="AUTOCOMMIT") as conn:
|
||||
for xid in conn.exec_driver_sql(
|
||||
"select gid from pg_prepared_xacts"
|
||||
).scalars():
|
||||
conn.execute("ROLLBACK PREPARED '%s'" % xid)
|
||||
|
||||
|
||||
@drop_all_schema_objects_post_tables.for_db("postgresql")
|
||||
def drop_all_schema_objects_post_tables(cfg, eng):
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
inspector = inspect(eng)
|
||||
with eng.begin() as conn:
|
||||
for enum in inspector.get_enums("*"):
|
||||
conn.execute(
|
||||
postgresql.DropEnumType(
|
||||
postgresql.ENUM(name=enum["name"], schema=enum["schema"])
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@prepare_for_drop_tables.for_db("postgresql")
|
||||
def prepare_for_drop_tables(config, connection):
|
||||
"""Ensure there are no locks on the current username/database."""
|
||||
|
||||
result = connection.exec_driver_sql(
|
||||
"select pid, state, wait_event_type, query "
|
||||
# "select pg_terminate_backend(pid), state, wait_event_type "
|
||||
"from pg_stat_activity where "
|
||||
"usename=current_user "
|
||||
"and datname=current_database() and state='idle in transaction' "
|
||||
"and pid != pg_backend_pid()"
|
||||
)
|
||||
rows = result.all() # noqa
|
||||
if rows:
|
||||
warn_test_suite(
|
||||
"PostgreSQL may not be able to DROP tables due to "
|
||||
"idle in transaction: %s"
|
||||
% ("; ".join(row._mapping["query"] for row in rows))
|
||||
)
|
||||
|
||||
|
||||
@upsert.for_db("postgresql")
|
||||
def _upsert(
|
||||
cfg, table, returning, *, set_lambda=None, sort_by_parameter_order=False
|
||||
):
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
|
||||
stmt = insert(table)
|
||||
|
||||
table_pk = inspect(table).selectable
|
||||
|
||||
if set_lambda:
|
||||
stmt = stmt.on_conflict_do_update(
|
||||
index_elements=table_pk.primary_key, set_=set_lambda(stmt.excluded)
|
||||
)
|
||||
else:
|
||||
stmt = stmt.on_conflict_do_nothing()
|
||||
|
||||
stmt = stmt.returning(
|
||||
*returning, sort_by_parameter_order=sort_by_parameter_order
|
||||
)
|
||||
return stmt
|
||||
|
||||
|
||||
_extensions = [
|
||||
("citext", (13,)),
|
||||
("hstore", (13,)),
|
||||
]
|
||||
|
||||
|
||||
@post_configure_engine.for_db("postgresql")
|
||||
def _create_citext_extension(url, engine, follower_ident):
|
||||
with engine.connect() as conn:
|
||||
for extension, min_version in _extensions:
|
||||
if conn.dialect.server_version_info >= min_version:
|
||||
conn.execute(
|
||||
text(f"CREATE EXTENSION IF NOT EXISTS {extension}")
|
||||
)
|
||||
conn.commit()
|
|
@ -0,0 +1,749 @@
|
|||
# dialects/postgresql/psycopg.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
r"""
|
||||
.. dialect:: postgresql+psycopg
|
||||
:name: psycopg (a.k.a. psycopg 3)
|
||||
:dbapi: psycopg
|
||||
:connectstring: postgresql+psycopg://user:password@host:port/dbname[?key=value&key=value...]
|
||||
:url: https://pypi.org/project/psycopg/
|
||||
|
||||
``psycopg`` is the package and module name for version 3 of the ``psycopg``
|
||||
database driver, formerly known as ``psycopg2``. This driver is different
|
||||
enough from its ``psycopg2`` predecessor that SQLAlchemy supports it
|
||||
via a totally separate dialect; support for ``psycopg2`` is expected to remain
|
||||
for as long as that package continues to function for modern Python versions,
|
||||
and also remains the default dialect for the ``postgresql://`` dialect
|
||||
series.
|
||||
|
||||
The SQLAlchemy ``psycopg`` dialect provides both a sync and an async
|
||||
implementation under the same dialect name. The proper version is
|
||||
selected depending on how the engine is created:
|
||||
|
||||
* calling :func:`_sa.create_engine` with ``postgresql+psycopg://...`` will
|
||||
automatically select the sync version, e.g.::
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
sync_engine = create_engine("postgresql+psycopg://scott:tiger@localhost/test")
|
||||
|
||||
* calling :func:`_asyncio.create_async_engine` with
|
||||
``postgresql+psycopg://...`` will automatically select the async version,
|
||||
e.g.::
|
||||
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
asyncio_engine = create_async_engine("postgresql+psycopg://scott:tiger@localhost/test")
|
||||
|
||||
The asyncio version of the dialect may also be specified explicitly using the
|
||||
``psycopg_async`` suffix, as::
|
||||
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
asyncio_engine = create_async_engine("postgresql+psycopg_async://scott:tiger@localhost/test")
|
||||
|
||||
.. seealso::
|
||||
|
||||
:ref:`postgresql_psycopg2` - The SQLAlchemy ``psycopg``
|
||||
dialect shares most of its behavior with the ``psycopg2`` dialect.
|
||||
Further documentation is available there.
|
||||
|
||||
""" # noqa
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import cast
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from . import ranges
|
||||
from ._psycopg_common import _PGDialect_common_psycopg
|
||||
from ._psycopg_common import _PGExecutionContext_common_psycopg
|
||||
from .base import INTERVAL
|
||||
from .base import PGCompiler
|
||||
from .base import PGIdentifierPreparer
|
||||
from .base import REGCONFIG
|
||||
from .json import JSON
|
||||
from .json import JSONB
|
||||
from .json import JSONPathType
|
||||
from .types import CITEXT
|
||||
from ... import pool
|
||||
from ... import util
|
||||
from ...engine import AdaptedConnection
|
||||
from ...sql import sqltypes
|
||||
from ...util.concurrency import await_fallback
|
||||
from ...util.concurrency import await_only
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing import Iterable
|
||||
|
||||
from psycopg import AsyncConnection
|
||||
|
||||
logger = logging.getLogger("sqlalchemy.dialects.postgresql")
|
||||
|
||||
|
||||
class _PGString(sqltypes.String):
|
||||
render_bind_cast = True
|
||||
|
||||
|
||||
class _PGREGCONFIG(REGCONFIG):
|
||||
render_bind_cast = True
|
||||
|
||||
|
||||
class _PGJSON(JSON):
|
||||
render_bind_cast = True
|
||||
|
||||
def bind_processor(self, dialect):
|
||||
return self._make_bind_processor(None, dialect._psycopg_Json)
|
||||
|
||||
def result_processor(self, dialect, coltype):
|
||||
return None
|
||||
|
||||
|
||||
class _PGJSONB(JSONB):
|
||||
render_bind_cast = True
|
||||
|
||||
def bind_processor(self, dialect):
|
||||
return self._make_bind_processor(None, dialect._psycopg_Jsonb)
|
||||
|
||||
def result_processor(self, dialect, coltype):
|
||||
return None
|
||||
|
||||
|
||||
class _PGJSONIntIndexType(sqltypes.JSON.JSONIntIndexType):
|
||||
__visit_name__ = "json_int_index"
|
||||
|
||||
render_bind_cast = True
|
||||
|
||||
|
||||
class _PGJSONStrIndexType(sqltypes.JSON.JSONStrIndexType):
|
||||
__visit_name__ = "json_str_index"
|
||||
|
||||
render_bind_cast = True
|
||||
|
||||
|
||||
class _PGJSONPathType(JSONPathType):
|
||||
pass
|
||||
|
||||
|
||||
class _PGInterval(INTERVAL):
|
||||
render_bind_cast = True
|
||||
|
||||
|
||||
class _PGTimeStamp(sqltypes.DateTime):
|
||||
render_bind_cast = True
|
||||
|
||||
|
||||
class _PGDate(sqltypes.Date):
|
||||
render_bind_cast = True
|
||||
|
||||
|
||||
class _PGTime(sqltypes.Time):
|
||||
render_bind_cast = True
|
||||
|
||||
|
||||
class _PGInteger(sqltypes.Integer):
|
||||
render_bind_cast = True
|
||||
|
||||
|
||||
class _PGSmallInteger(sqltypes.SmallInteger):
|
||||
render_bind_cast = True
|
||||
|
||||
|
||||
class _PGNullType(sqltypes.NullType):
|
||||
render_bind_cast = True
|
||||
|
||||
|
||||
class _PGBigInteger(sqltypes.BigInteger):
|
||||
render_bind_cast = True
|
||||
|
||||
|
||||
class _PGBoolean(sqltypes.Boolean):
|
||||
render_bind_cast = True
|
||||
|
||||
|
||||
class _PsycopgRange(ranges.AbstractSingleRangeImpl):
|
||||
def bind_processor(self, dialect):
|
||||
psycopg_Range = cast(PGDialect_psycopg, dialect)._psycopg_Range
|
||||
|
||||
def to_range(value):
|
||||
if isinstance(value, ranges.Range):
|
||||
value = psycopg_Range(
|
||||
value.lower, value.upper, value.bounds, value.empty
|
||||
)
|
||||
return value
|
||||
|
||||
return to_range
|
||||
|
||||
def result_processor(self, dialect, coltype):
|
||||
def to_range(value):
|
||||
if value is not None:
|
||||
value = ranges.Range(
|
||||
value._lower,
|
||||
value._upper,
|
||||
bounds=value._bounds if value._bounds else "[)",
|
||||
empty=not value._bounds,
|
||||
)
|
||||
return value
|
||||
|
||||
return to_range
|
||||
|
||||
|
||||
class _PsycopgMultiRange(ranges.AbstractMultiRangeImpl):
|
||||
def bind_processor(self, dialect):
|
||||
psycopg_Range = cast(PGDialect_psycopg, dialect)._psycopg_Range
|
||||
psycopg_Multirange = cast(
|
||||
PGDialect_psycopg, dialect
|
||||
)._psycopg_Multirange
|
||||
|
||||
NoneType = type(None)
|
||||
|
||||
def to_range(value):
|
||||
if isinstance(value, (str, NoneType, psycopg_Multirange)):
|
||||
return value
|
||||
|
||||
return psycopg_Multirange(
|
||||
[
|
||||
psycopg_Range(
|
||||
element.lower,
|
||||
element.upper,
|
||||
element.bounds,
|
||||
element.empty,
|
||||
)
|
||||
for element in cast("Iterable[ranges.Range]", value)
|
||||
]
|
||||
)
|
||||
|
||||
return to_range
|
||||
|
||||
def result_processor(self, dialect, coltype):
|
||||
def to_range(value):
|
||||
if value is None:
|
||||
return None
|
||||
else:
|
||||
return ranges.MultiRange(
|
||||
ranges.Range(
|
||||
elem._lower,
|
||||
elem._upper,
|
||||
bounds=elem._bounds if elem._bounds else "[)",
|
||||
empty=not elem._bounds,
|
||||
)
|
||||
for elem in value
|
||||
)
|
||||
|
||||
return to_range
|
||||
|
||||
|
||||
class PGExecutionContext_psycopg(_PGExecutionContext_common_psycopg):
|
||||
pass
|
||||
|
||||
|
||||
class PGCompiler_psycopg(PGCompiler):
|
||||
pass
|
||||
|
||||
|
||||
class PGIdentifierPreparer_psycopg(PGIdentifierPreparer):
|
||||
pass
|
||||
|
||||
|
||||
def _log_notices(diagnostic):
|
||||
logger.info("%s: %s", diagnostic.severity, diagnostic.message_primary)
|
||||
|
||||
|
||||
class PGDialect_psycopg(_PGDialect_common_psycopg):
|
||||
driver = "psycopg"
|
||||
|
||||
supports_statement_cache = True
|
||||
supports_server_side_cursors = True
|
||||
default_paramstyle = "pyformat"
|
||||
supports_sane_multi_rowcount = True
|
||||
|
||||
execution_ctx_cls = PGExecutionContext_psycopg
|
||||
statement_compiler = PGCompiler_psycopg
|
||||
preparer = PGIdentifierPreparer_psycopg
|
||||
psycopg_version = (0, 0)
|
||||
|
||||
_has_native_hstore = True
|
||||
_psycopg_adapters_map = None
|
||||
|
||||
colspecs = util.update_copy(
|
||||
_PGDialect_common_psycopg.colspecs,
|
||||
{
|
||||
sqltypes.String: _PGString,
|
||||
REGCONFIG: _PGREGCONFIG,
|
||||
JSON: _PGJSON,
|
||||
CITEXT: CITEXT,
|
||||
sqltypes.JSON: _PGJSON,
|
||||
JSONB: _PGJSONB,
|
||||
sqltypes.JSON.JSONPathType: _PGJSONPathType,
|
||||
sqltypes.JSON.JSONIntIndexType: _PGJSONIntIndexType,
|
||||
sqltypes.JSON.JSONStrIndexType: _PGJSONStrIndexType,
|
||||
sqltypes.Interval: _PGInterval,
|
||||
INTERVAL: _PGInterval,
|
||||
sqltypes.Date: _PGDate,
|
||||
sqltypes.DateTime: _PGTimeStamp,
|
||||
sqltypes.Time: _PGTime,
|
||||
sqltypes.Integer: _PGInteger,
|
||||
sqltypes.SmallInteger: _PGSmallInteger,
|
||||
sqltypes.BigInteger: _PGBigInteger,
|
||||
ranges.AbstractSingleRange: _PsycopgRange,
|
||||
ranges.AbstractMultiRange: _PsycopgMultiRange,
|
||||
},
|
||||
)
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if self.dbapi:
|
||||
m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", self.dbapi.__version__)
|
||||
if m:
|
||||
self.psycopg_version = tuple(
|
||||
int(x) for x in m.group(1, 2, 3) if x is not None
|
||||
)
|
||||
|
||||
if self.psycopg_version < (3, 0, 2):
|
||||
raise ImportError(
|
||||
"psycopg version 3.0.2 or higher is required."
|
||||
)
|
||||
|
||||
from psycopg.adapt import AdaptersMap
|
||||
|
||||
self._psycopg_adapters_map = adapters_map = AdaptersMap(
|
||||
self.dbapi.adapters
|
||||
)
|
||||
|
||||
if self._native_inet_types is False:
|
||||
import psycopg.types.string
|
||||
|
||||
adapters_map.register_loader(
|
||||
"inet", psycopg.types.string.TextLoader
|
||||
)
|
||||
adapters_map.register_loader(
|
||||
"cidr", psycopg.types.string.TextLoader
|
||||
)
|
||||
|
||||
if self._json_deserializer:
|
||||
from psycopg.types.json import set_json_loads
|
||||
|
||||
set_json_loads(self._json_deserializer, adapters_map)
|
||||
|
||||
if self._json_serializer:
|
||||
from psycopg.types.json import set_json_dumps
|
||||
|
||||
set_json_dumps(self._json_serializer, adapters_map)
|
||||
|
||||
def create_connect_args(self, url):
|
||||
# see https://github.com/psycopg/psycopg/issues/83
|
||||
cargs, cparams = super().create_connect_args(url)
|
||||
|
||||
if self._psycopg_adapters_map:
|
||||
cparams["context"] = self._psycopg_adapters_map
|
||||
if self.client_encoding is not None:
|
||||
cparams["client_encoding"] = self.client_encoding
|
||||
return cargs, cparams
|
||||
|
||||
def _type_info_fetch(self, connection, name):
|
||||
from psycopg.types import TypeInfo
|
||||
|
||||
return TypeInfo.fetch(connection.connection.driver_connection, name)
|
||||
|
||||
def initialize(self, connection):
|
||||
super().initialize(connection)
|
||||
|
||||
# PGDialect.initialize() checks server version for <= 8.2 and sets
|
||||
# this flag to False if so
|
||||
if not self.insert_returning:
|
||||
self.insert_executemany_returning = False
|
||||
|
||||
# HSTORE can't be registered until we have a connection so that
|
||||
# we can look up its OID, so we set up this adapter in
|
||||
# initialize()
|
||||
if self.use_native_hstore:
|
||||
info = self._type_info_fetch(connection, "hstore")
|
||||
self._has_native_hstore = info is not None
|
||||
if self._has_native_hstore:
|
||||
from psycopg.types.hstore import register_hstore
|
||||
|
||||
# register the adapter for connections made subsequent to
|
||||
# this one
|
||||
register_hstore(info, self._psycopg_adapters_map)
|
||||
|
||||
# register the adapter for this connection
|
||||
register_hstore(info, connection.connection)
|
||||
|
||||
@classmethod
|
||||
def import_dbapi(cls):
|
||||
import psycopg
|
||||
|
||||
return psycopg
|
||||
|
||||
@classmethod
|
||||
def get_async_dialect_cls(cls, url):
|
||||
return PGDialectAsync_psycopg
|
||||
|
||||
@util.memoized_property
|
||||
def _isolation_lookup(self):
|
||||
return {
|
||||
"READ COMMITTED": self.dbapi.IsolationLevel.READ_COMMITTED,
|
||||
"READ UNCOMMITTED": self.dbapi.IsolationLevel.READ_UNCOMMITTED,
|
||||
"REPEATABLE READ": self.dbapi.IsolationLevel.REPEATABLE_READ,
|
||||
"SERIALIZABLE": self.dbapi.IsolationLevel.SERIALIZABLE,
|
||||
}
|
||||
|
||||
@util.memoized_property
|
||||
def _psycopg_Json(self):
|
||||
from psycopg.types import json
|
||||
|
||||
return json.Json
|
||||
|
||||
@util.memoized_property
|
||||
def _psycopg_Jsonb(self):
|
||||
from psycopg.types import json
|
||||
|
||||
return json.Jsonb
|
||||
|
||||
@util.memoized_property
|
||||
def _psycopg_TransactionStatus(self):
|
||||
from psycopg.pq import TransactionStatus
|
||||
|
||||
return TransactionStatus
|
||||
|
||||
@util.memoized_property
|
||||
def _psycopg_Range(self):
|
||||
from psycopg.types.range import Range
|
||||
|
||||
return Range
|
||||
|
||||
@util.memoized_property
|
||||
def _psycopg_Multirange(self):
|
||||
from psycopg.types.multirange import Multirange
|
||||
|
||||
return Multirange
|
||||
|
||||
def _do_isolation_level(self, connection, autocommit, isolation_level):
|
||||
connection.autocommit = autocommit
|
||||
connection.isolation_level = isolation_level
|
||||
|
||||
def get_isolation_level(self, dbapi_connection):
|
||||
status_before = dbapi_connection.info.transaction_status
|
||||
value = super().get_isolation_level(dbapi_connection)
|
||||
|
||||
# don't rely on psycopg providing enum symbols, compare with
|
||||
# eq/ne
|
||||
if status_before == self._psycopg_TransactionStatus.IDLE:
|
||||
dbapi_connection.rollback()
|
||||
return value
|
||||
|
||||
def set_isolation_level(self, dbapi_connection, level):
|
||||
if level == "AUTOCOMMIT":
|
||||
self._do_isolation_level(
|
||||
dbapi_connection, autocommit=True, isolation_level=None
|
||||
)
|
||||
else:
|
||||
self._do_isolation_level(
|
||||
dbapi_connection,
|
||||
autocommit=False,
|
||||
isolation_level=self._isolation_lookup[level],
|
||||
)
|
||||
|
||||
def set_readonly(self, connection, value):
|
||||
connection.read_only = value
|
||||
|
||||
def get_readonly(self, connection):
|
||||
return connection.read_only
|
||||
|
||||
def on_connect(self):
|
||||
def notices(conn):
|
||||
conn.add_notice_handler(_log_notices)
|
||||
|
||||
fns = [notices]
|
||||
|
||||
if self.isolation_level is not None:
|
||||
|
||||
def on_connect(conn):
|
||||
self.set_isolation_level(conn, self.isolation_level)
|
||||
|
||||
fns.append(on_connect)
|
||||
|
||||
# fns always has the notices function
|
||||
def on_connect(conn):
|
||||
for fn in fns:
|
||||
fn(conn)
|
||||
|
||||
return on_connect
|
||||
|
||||
def is_disconnect(self, e, connection, cursor):
|
||||
if isinstance(e, self.dbapi.Error) and connection is not None:
|
||||
if connection.closed or connection.broken:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _do_prepared_twophase(self, connection, command, recover=False):
|
||||
dbapi_conn = connection.connection.dbapi_connection
|
||||
if (
|
||||
recover
|
||||
# don't rely on psycopg providing enum symbols, compare with
|
||||
# eq/ne
|
||||
or dbapi_conn.info.transaction_status
|
||||
!= self._psycopg_TransactionStatus.IDLE
|
||||
):
|
||||
dbapi_conn.rollback()
|
||||
before_autocommit = dbapi_conn.autocommit
|
||||
try:
|
||||
if not before_autocommit:
|
||||
self._do_autocommit(dbapi_conn, True)
|
||||
dbapi_conn.execute(command)
|
||||
finally:
|
||||
if not before_autocommit:
|
||||
self._do_autocommit(dbapi_conn, before_autocommit)
|
||||
|
||||
def do_rollback_twophase(
|
||||
self, connection, xid, is_prepared=True, recover=False
|
||||
):
|
||||
if is_prepared:
|
||||
self._do_prepared_twophase(
|
||||
connection, f"ROLLBACK PREPARED '{xid}'", recover=recover
|
||||
)
|
||||
else:
|
||||
self.do_rollback(connection.connection)
|
||||
|
||||
def do_commit_twophase(
|
||||
self, connection, xid, is_prepared=True, recover=False
|
||||
):
|
||||
if is_prepared:
|
||||
self._do_prepared_twophase(
|
||||
connection, f"COMMIT PREPARED '{xid}'", recover=recover
|
||||
)
|
||||
else:
|
||||
self.do_commit(connection.connection)
|
||||
|
||||
@util.memoized_property
|
||||
def _dialect_specific_select_one(self):
|
||||
return ";"
|
||||
|
||||
|
||||
class AsyncAdapt_psycopg_cursor:
|
||||
__slots__ = ("_cursor", "await_", "_rows")
|
||||
|
||||
_psycopg_ExecStatus = None
|
||||
|
||||
def __init__(self, cursor, await_) -> None:
|
||||
self._cursor = cursor
|
||||
self.await_ = await_
|
||||
self._rows = []
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._cursor, name)
|
||||
|
||||
@property
|
||||
def arraysize(self):
|
||||
return self._cursor.arraysize
|
||||
|
||||
@arraysize.setter
|
||||
def arraysize(self, value):
|
||||
self._cursor.arraysize = value
|
||||
|
||||
def close(self):
|
||||
self._rows.clear()
|
||||
# Normal cursor just call _close() in a non-sync way.
|
||||
self._cursor._close()
|
||||
|
||||
def execute(self, query, params=None, **kw):
|
||||
result = self.await_(self._cursor.execute(query, params, **kw))
|
||||
# sqlalchemy result is not async, so need to pull all rows here
|
||||
res = self._cursor.pgresult
|
||||
|
||||
# don't rely on psycopg providing enum symbols, compare with
|
||||
# eq/ne
|
||||
if res and res.status == self._psycopg_ExecStatus.TUPLES_OK:
|
||||
rows = self.await_(self._cursor.fetchall())
|
||||
if not isinstance(rows, list):
|
||||
self._rows = list(rows)
|
||||
else:
|
||||
self._rows = rows
|
||||
return result
|
||||
|
||||
def executemany(self, query, params_seq):
|
||||
return self.await_(self._cursor.executemany(query, params_seq))
|
||||
|
||||
def __iter__(self):
|
||||
# TODO: try to avoid pop(0) on a list
|
||||
while self._rows:
|
||||
yield self._rows.pop(0)
|
||||
|
||||
def fetchone(self):
|
||||
if self._rows:
|
||||
# TODO: try to avoid pop(0) on a list
|
||||
return self._rows.pop(0)
|
||||
else:
|
||||
return None
|
||||
|
||||
def fetchmany(self, size=None):
|
||||
if size is None:
|
||||
size = self._cursor.arraysize
|
||||
|
||||
retval = self._rows[0:size]
|
||||
self._rows = self._rows[size:]
|
||||
return retval
|
||||
|
||||
def fetchall(self):
|
||||
retval = self._rows
|
||||
self._rows = []
|
||||
return retval
|
||||
|
||||
|
||||
class AsyncAdapt_psycopg_ss_cursor(AsyncAdapt_psycopg_cursor):
|
||||
def execute(self, query, params=None, **kw):
|
||||
self.await_(self._cursor.execute(query, params, **kw))
|
||||
return self
|
||||
|
||||
def close(self):
|
||||
self.await_(self._cursor.close())
|
||||
|
||||
def fetchone(self):
|
||||
return self.await_(self._cursor.fetchone())
|
||||
|
||||
def fetchmany(self, size=0):
|
||||
return self.await_(self._cursor.fetchmany(size))
|
||||
|
||||
def fetchall(self):
|
||||
return self.await_(self._cursor.fetchall())
|
||||
|
||||
def __iter__(self):
|
||||
iterator = self._cursor.__aiter__()
|
||||
while True:
|
||||
try:
|
||||
yield self.await_(iterator.__anext__())
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
|
||||
|
||||
class AsyncAdapt_psycopg_connection(AdaptedConnection):
|
||||
_connection: AsyncConnection
|
||||
__slots__ = ()
|
||||
await_ = staticmethod(await_only)
|
||||
|
||||
def __init__(self, connection) -> None:
|
||||
self._connection = connection
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self._connection, name)
|
||||
|
||||
def execute(self, query, params=None, **kw):
|
||||
cursor = self.await_(self._connection.execute(query, params, **kw))
|
||||
return AsyncAdapt_psycopg_cursor(cursor, self.await_)
|
||||
|
||||
def cursor(self, *args, **kw):
|
||||
cursor = self._connection.cursor(*args, **kw)
|
||||
if hasattr(cursor, "name"):
|
||||
return AsyncAdapt_psycopg_ss_cursor(cursor, self.await_)
|
||||
else:
|
||||
return AsyncAdapt_psycopg_cursor(cursor, self.await_)
|
||||
|
||||
def commit(self):
|
||||
self.await_(self._connection.commit())
|
||||
|
||||
def rollback(self):
|
||||
self.await_(self._connection.rollback())
|
||||
|
||||
def close(self):
|
||||
self.await_(self._connection.close())
|
||||
|
||||
@property
|
||||
def autocommit(self):
|
||||
return self._connection.autocommit
|
||||
|
||||
@autocommit.setter
|
||||
def autocommit(self, value):
|
||||
self.set_autocommit(value)
|
||||
|
||||
def set_autocommit(self, value):
|
||||
self.await_(self._connection.set_autocommit(value))
|
||||
|
||||
def set_isolation_level(self, value):
|
||||
self.await_(self._connection.set_isolation_level(value))
|
||||
|
||||
def set_read_only(self, value):
|
||||
self.await_(self._connection.set_read_only(value))
|
||||
|
||||
def set_deferrable(self, value):
|
||||
self.await_(self._connection.set_deferrable(value))
|
||||
|
||||
|
||||
class AsyncAdaptFallback_psycopg_connection(AsyncAdapt_psycopg_connection):
|
||||
__slots__ = ()
|
||||
await_ = staticmethod(await_fallback)
|
||||
|
||||
|
||||
class PsycopgAdaptDBAPI:
|
||||
def __init__(self, psycopg) -> None:
|
||||
self.psycopg = psycopg
|
||||
|
||||
for k, v in self.psycopg.__dict__.items():
|
||||
if k != "connect":
|
||||
self.__dict__[k] = v
|
||||
|
||||
def connect(self, *arg, **kw):
|
||||
async_fallback = kw.pop("async_fallback", False)
|
||||
creator_fn = kw.pop(
|
||||
"async_creator_fn", self.psycopg.AsyncConnection.connect
|
||||
)
|
||||
if util.asbool(async_fallback):
|
||||
return AsyncAdaptFallback_psycopg_connection(
|
||||
await_fallback(creator_fn(*arg, **kw))
|
||||
)
|
||||
else:
|
||||
return AsyncAdapt_psycopg_connection(
|
||||
await_only(creator_fn(*arg, **kw))
|
||||
)
|
||||
|
||||
|
||||
class PGDialectAsync_psycopg(PGDialect_psycopg):
|
||||
is_async = True
|
||||
supports_statement_cache = True
|
||||
|
||||
@classmethod
|
||||
def import_dbapi(cls):
|
||||
import psycopg
|
||||
from psycopg.pq import ExecStatus
|
||||
|
||||
AsyncAdapt_psycopg_cursor._psycopg_ExecStatus = ExecStatus
|
||||
|
||||
return PsycopgAdaptDBAPI(psycopg)
|
||||
|
||||
@classmethod
|
||||
def get_pool_class(cls, url):
|
||||
async_fallback = url.query.get("async_fallback", False)
|
||||
|
||||
if util.asbool(async_fallback):
|
||||
return pool.FallbackAsyncAdaptedQueuePool
|
||||
else:
|
||||
return pool.AsyncAdaptedQueuePool
|
||||
|
||||
def _type_info_fetch(self, connection, name):
|
||||
from psycopg.types import TypeInfo
|
||||
|
||||
adapted = connection.connection
|
||||
return adapted.await_(TypeInfo.fetch(adapted.driver_connection, name))
|
||||
|
||||
def _do_isolation_level(self, connection, autocommit, isolation_level):
|
||||
connection.set_autocommit(autocommit)
|
||||
connection.set_isolation_level(isolation_level)
|
||||
|
||||
def _do_autocommit(self, connection, value):
|
||||
connection.set_autocommit(value)
|
||||
|
||||
def set_readonly(self, connection, value):
|
||||
connection.set_read_only(value)
|
||||
|
||||
def set_deferrable(self, connection, value):
|
||||
connection.set_deferrable(value)
|
||||
|
||||
def get_driver_connection(self, connection):
|
||||
return connection._connection
|
||||
|
||||
|
||||
dialect = PGDialect_psycopg
|
||||
dialect_async = PGDialectAsync_psycopg
|
|
@ -0,0 +1,876 @@
|
|||
# dialects/postgresql/psycopg2.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
r"""
|
||||
.. dialect:: postgresql+psycopg2
|
||||
:name: psycopg2
|
||||
:dbapi: psycopg2
|
||||
:connectstring: postgresql+psycopg2://user:password@host:port/dbname[?key=value&key=value...]
|
||||
:url: https://pypi.org/project/psycopg2/
|
||||
|
||||
.. _psycopg2_toplevel:
|
||||
|
||||
psycopg2 Connect Arguments
|
||||
--------------------------
|
||||
|
||||
Keyword arguments that are specific to the SQLAlchemy psycopg2 dialect
|
||||
may be passed to :func:`_sa.create_engine()`, and include the following:
|
||||
|
||||
|
||||
* ``isolation_level``: This option, available for all PostgreSQL dialects,
|
||||
includes the ``AUTOCOMMIT`` isolation level when using the psycopg2
|
||||
dialect. This option sets the **default** isolation level for the
|
||||
connection that is set immediately upon connection to the database before
|
||||
the connection is pooled. This option is generally superseded by the more
|
||||
modern :paramref:`_engine.Connection.execution_options.isolation_level`
|
||||
execution option, detailed at :ref:`dbapi_autocommit`.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:ref:`psycopg2_isolation_level`
|
||||
|
||||
:ref:`dbapi_autocommit`
|
||||
|
||||
|
||||
* ``client_encoding``: sets the client encoding in a libpq-agnostic way,
|
||||
using psycopg2's ``set_client_encoding()`` method.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:ref:`psycopg2_unicode`
|
||||
|
||||
|
||||
* ``executemany_mode``, ``executemany_batch_page_size``,
|
||||
``executemany_values_page_size``: Allows use of psycopg2
|
||||
extensions for optimizing "executemany"-style queries. See the referenced
|
||||
section below for details.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:ref:`psycopg2_executemany_mode`
|
||||
|
||||
.. tip::
|
||||
|
||||
The above keyword arguments are **dialect** keyword arguments, meaning
|
||||
that they are passed as explicit keyword arguments to :func:`_sa.create_engine()`::
|
||||
|
||||
engine = create_engine(
|
||||
"postgresql+psycopg2://scott:tiger@localhost/test",
|
||||
isolation_level="SERIALIZABLE",
|
||||
)
|
||||
|
||||
These should not be confused with **DBAPI** connect arguments, which
|
||||
are passed as part of the :paramref:`_sa.create_engine.connect_args`
|
||||
dictionary and/or are passed in the URL query string, as detailed in
|
||||
the section :ref:`custom_dbapi_args`.
|
||||
|
||||
.. _psycopg2_ssl:
|
||||
|
||||
SSL Connections
|
||||
---------------
|
||||
|
||||
The psycopg2 module has a connection argument named ``sslmode`` for
|
||||
controlling its behavior regarding secure (SSL) connections. The default is
|
||||
``sslmode=prefer``; it will attempt an SSL connection and if that fails it
|
||||
will fall back to an unencrypted connection. ``sslmode=require`` may be used
|
||||
to ensure that only secure connections are established. Consult the
|
||||
psycopg2 / libpq documentation for further options that are available.
|
||||
|
||||
Note that ``sslmode`` is specific to psycopg2 so it is included in the
|
||||
connection URI::
|
||||
|
||||
engine = sa.create_engine(
|
||||
"postgresql+psycopg2://scott:tiger@192.168.0.199:5432/test?sslmode=require"
|
||||
)
|
||||
|
||||
|
||||
Unix Domain Connections
|
||||
------------------------
|
||||
|
||||
psycopg2 supports connecting via Unix domain connections. When the ``host``
|
||||
portion of the URL is omitted, SQLAlchemy passes ``None`` to psycopg2,
|
||||
which specifies Unix-domain communication rather than TCP/IP communication::
|
||||
|
||||
create_engine("postgresql+psycopg2://user:password@/dbname")
|
||||
|
||||
By default, the socket file used is to connect to a Unix-domain socket
|
||||
in ``/tmp``, or whatever socket directory was specified when PostgreSQL
|
||||
was built. This value can be overridden by passing a pathname to psycopg2,
|
||||
using ``host`` as an additional keyword argument::
|
||||
|
||||
create_engine("postgresql+psycopg2://user:password@/dbname?host=/var/lib/postgresql")
|
||||
|
||||
.. warning:: The format accepted here allows for a hostname in the main URL
|
||||
in addition to the "host" query string argument. **When using this URL
|
||||
format, the initial host is silently ignored**. That is, this URL::
|
||||
|
||||
engine = create_engine("postgresql+psycopg2://user:password@myhost1/dbname?host=myhost2")
|
||||
|
||||
Above, the hostname ``myhost1`` is **silently ignored and discarded.** The
|
||||
host which is connected is the ``myhost2`` host.
|
||||
|
||||
This is to maintain some degree of compatibility with PostgreSQL's own URL
|
||||
format which has been tested to behave the same way and for which tools like
|
||||
PifPaf hardcode two hostnames.
|
||||
|
||||
.. seealso::
|
||||
|
||||
`PQconnectdbParams \
|
||||
<https://www.postgresql.org/docs/current/static/libpq-connect.html#LIBPQ-PQCONNECTDBPARAMS>`_
|
||||
|
||||
.. _psycopg2_multi_host:
|
||||
|
||||
Specifying multiple fallback hosts
|
||||
-----------------------------------
|
||||
|
||||
psycopg2 supports multiple connection points in the connection string.
|
||||
When the ``host`` parameter is used multiple times in the query section of
|
||||
the URL, SQLAlchemy will create a single string of the host and port
|
||||
information provided to make the connections. Tokens may consist of
|
||||
``host::port`` or just ``host``; in the latter case, the default port
|
||||
is selected by libpq. In the example below, three host connections
|
||||
are specified, for ``HostA::PortA``, ``HostB`` connecting to the default port,
|
||||
and ``HostC::PortC``::
|
||||
|
||||
create_engine(
|
||||
"postgresql+psycopg2://user:password@/dbname?host=HostA:PortA&host=HostB&host=HostC:PortC"
|
||||
)
|
||||
|
||||
As an alternative, libpq query string format also may be used; this specifies
|
||||
``host`` and ``port`` as single query string arguments with comma-separated
|
||||
lists - the default port can be chosen by indicating an empty value
|
||||
in the comma separated list::
|
||||
|
||||
create_engine(
|
||||
"postgresql+psycopg2://user:password@/dbname?host=HostA,HostB,HostC&port=PortA,,PortC"
|
||||
)
|
||||
|
||||
With either URL style, connections to each host is attempted based on a
|
||||
configurable strategy, which may be configured using the libpq
|
||||
``target_session_attrs`` parameter. Per libpq this defaults to ``any``
|
||||
which indicates a connection to each host is then attempted until a connection is successful.
|
||||
Other strategies include ``primary``, ``prefer-standby``, etc. The complete
|
||||
list is documented by PostgreSQL at
|
||||
`libpq connection strings <https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING>`_.
|
||||
|
||||
For example, to indicate two hosts using the ``primary`` strategy::
|
||||
|
||||
create_engine(
|
||||
"postgresql+psycopg2://user:password@/dbname?host=HostA:PortA&host=HostB&host=HostC:PortC&target_session_attrs=primary"
|
||||
)
|
||||
|
||||
.. versionchanged:: 1.4.40 Port specification in psycopg2 multiple host format
|
||||
is repaired, previously ports were not correctly interpreted in this context.
|
||||
libpq comma-separated format is also now supported.
|
||||
|
||||
.. versionadded:: 1.3.20 Support for multiple hosts in PostgreSQL connection
|
||||
string.
|
||||
|
||||
.. seealso::
|
||||
|
||||
`libpq connection strings <https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING>`_ - please refer
|
||||
to this section in the libpq documentation for complete background on multiple host support.
|
||||
|
||||
|
||||
Empty DSN Connections / Environment Variable Connections
|
||||
---------------------------------------------------------
|
||||
|
||||
The psycopg2 DBAPI can connect to PostgreSQL by passing an empty DSN to the
|
||||
libpq client library, which by default indicates to connect to a localhost
|
||||
PostgreSQL database that is open for "trust" connections. This behavior can be
|
||||
further tailored using a particular set of environment variables which are
|
||||
prefixed with ``PG_...``, which are consumed by ``libpq`` to take the place of
|
||||
any or all elements of the connection string.
|
||||
|
||||
For this form, the URL can be passed without any elements other than the
|
||||
initial scheme::
|
||||
|
||||
engine = create_engine('postgresql+psycopg2://')
|
||||
|
||||
In the above form, a blank "dsn" string is passed to the ``psycopg2.connect()``
|
||||
function which in turn represents an empty DSN passed to libpq.
|
||||
|
||||
.. versionadded:: 1.3.2 support for parameter-less connections with psycopg2.
|
||||
|
||||
.. seealso::
|
||||
|
||||
`Environment Variables\
|
||||
<https://www.postgresql.org/docs/current/libpq-envars.html>`_ -
|
||||
PostgreSQL documentation on how to use ``PG_...``
|
||||
environment variables for connections.
|
||||
|
||||
.. _psycopg2_execution_options:
|
||||
|
||||
Per-Statement/Connection Execution Options
|
||||
-------------------------------------------
|
||||
|
||||
The following DBAPI-specific options are respected when used with
|
||||
:meth:`_engine.Connection.execution_options`,
|
||||
:meth:`.Executable.execution_options`,
|
||||
:meth:`_query.Query.execution_options`,
|
||||
in addition to those not specific to DBAPIs:
|
||||
|
||||
* ``isolation_level`` - Set the transaction isolation level for the lifespan
|
||||
of a :class:`_engine.Connection` (can only be set on a connection,
|
||||
not a statement
|
||||
or query). See :ref:`psycopg2_isolation_level`.
|
||||
|
||||
* ``stream_results`` - Enable or disable usage of psycopg2 server side
|
||||
cursors - this feature makes use of "named" cursors in combination with
|
||||
special result handling methods so that result rows are not fully buffered.
|
||||
Defaults to False, meaning cursors are buffered by default.
|
||||
|
||||
* ``max_row_buffer`` - when using ``stream_results``, an integer value that
|
||||
specifies the maximum number of rows to buffer at a time. This is
|
||||
interpreted by the :class:`.BufferedRowCursorResult`, and if omitted the
|
||||
buffer will grow to ultimately store 1000 rows at a time.
|
||||
|
||||
.. versionchanged:: 1.4 The ``max_row_buffer`` size can now be greater than
|
||||
1000, and the buffer will grow to that size.
|
||||
|
||||
.. _psycopg2_batch_mode:
|
||||
|
||||
.. _psycopg2_executemany_mode:
|
||||
|
||||
Psycopg2 Fast Execution Helpers
|
||||
-------------------------------
|
||||
|
||||
Modern versions of psycopg2 include a feature known as
|
||||
`Fast Execution Helpers \
|
||||
<https://initd.org/psycopg/docs/extras.html#fast-execution-helpers>`_, which
|
||||
have been shown in benchmarking to improve psycopg2's executemany()
|
||||
performance, primarily with INSERT statements, by at least
|
||||
an order of magnitude.
|
||||
|
||||
SQLAlchemy implements a native form of the "insert many values"
|
||||
handler that will rewrite a single-row INSERT statement to accommodate for
|
||||
many values at once within an extended VALUES clause; this handler is
|
||||
equivalent to psycopg2's ``execute_values()`` handler; an overview of this
|
||||
feature and its configuration are at :ref:`engine_insertmanyvalues`.
|
||||
|
||||
.. versionadded:: 2.0 Replaced psycopg2's ``execute_values()`` fast execution
|
||||
helper with a native SQLAlchemy mechanism known as
|
||||
:ref:`insertmanyvalues <engine_insertmanyvalues>`.
|
||||
|
||||
The psycopg2 dialect retains the ability to use the psycopg2-specific
|
||||
``execute_batch()`` feature, although it is not expected that this is a widely
|
||||
used feature. The use of this extension may be enabled using the
|
||||
``executemany_mode`` flag which may be passed to :func:`_sa.create_engine`::
|
||||
|
||||
engine = create_engine(
|
||||
"postgresql+psycopg2://scott:tiger@host/dbname",
|
||||
executemany_mode='values_plus_batch')
|
||||
|
||||
|
||||
Possible options for ``executemany_mode`` include:
|
||||
|
||||
* ``values_only`` - this is the default value. SQLAlchemy's native
|
||||
:ref:`insertmanyvalues <engine_insertmanyvalues>` handler is used for qualifying
|
||||
INSERT statements, assuming
|
||||
:paramref:`_sa.create_engine.use_insertmanyvalues` is left at
|
||||
its default value of ``True``. This handler rewrites simple
|
||||
INSERT statements to include multiple VALUES clauses so that many
|
||||
parameter sets can be inserted with one statement.
|
||||
|
||||
* ``'values_plus_batch'``- SQLAlchemy's native
|
||||
:ref:`insertmanyvalues <engine_insertmanyvalues>` handler is used for qualifying
|
||||
INSERT statements, assuming
|
||||
:paramref:`_sa.create_engine.use_insertmanyvalues` is left at its default
|
||||
value of ``True``. Then, psycopg2's ``execute_batch()`` handler is used for
|
||||
qualifying UPDATE and DELETE statements when executed with multiple parameter
|
||||
sets. When using this mode, the :attr:`_engine.CursorResult.rowcount`
|
||||
attribute will not contain a value for executemany-style executions against
|
||||
UPDATE and DELETE statements.
|
||||
|
||||
.. versionchanged:: 2.0 Removed the ``'batch'`` and ``'None'`` options
|
||||
from psycopg2 ``executemany_mode``. Control over batching for INSERT
|
||||
statements is now configured via the
|
||||
:paramref:`_sa.create_engine.use_insertmanyvalues` engine-level parameter.
|
||||
|
||||
The term "qualifying statements" refers to the statement being executed
|
||||
being a Core :func:`_expression.insert`, :func:`_expression.update`
|
||||
or :func:`_expression.delete` construct, and **not** a plain textual SQL
|
||||
string or one constructed using :func:`_expression.text`. It also may **not** be
|
||||
a special "extension" statement such as an "ON CONFLICT" "upsert" statement.
|
||||
When using the ORM, all insert/update/delete statements used by the ORM flush process
|
||||
are qualifying.
|
||||
|
||||
The "page size" for the psycopg2 "batch" strategy can be affected
|
||||
by using the ``executemany_batch_page_size`` parameter, which defaults to
|
||||
100.
|
||||
|
||||
For the "insertmanyvalues" feature, the page size can be controlled using the
|
||||
:paramref:`_sa.create_engine.insertmanyvalues_page_size` parameter,
|
||||
which defaults to 1000. An example of modifying both parameters
|
||||
is below::
|
||||
|
||||
engine = create_engine(
|
||||
"postgresql+psycopg2://scott:tiger@host/dbname",
|
||||
executemany_mode='values_plus_batch',
|
||||
insertmanyvalues_page_size=5000, executemany_batch_page_size=500)
|
||||
|
||||
.. seealso::
|
||||
|
||||
:ref:`engine_insertmanyvalues` - background on "insertmanyvalues"
|
||||
|
||||
:ref:`tutorial_multiple_parameters` - General information on using the
|
||||
:class:`_engine.Connection`
|
||||
object to execute statements in such a way as to make
|
||||
use of the DBAPI ``.executemany()`` method.
|
||||
|
||||
|
||||
.. _psycopg2_unicode:
|
||||
|
||||
Unicode with Psycopg2
|
||||
----------------------
|
||||
|
||||
The psycopg2 DBAPI driver supports Unicode data transparently.
|
||||
|
||||
The client character encoding can be controlled for the psycopg2 dialect
|
||||
in the following ways:
|
||||
|
||||
* For PostgreSQL 9.1 and above, the ``client_encoding`` parameter may be
|
||||
passed in the database URL; this parameter is consumed by the underlying
|
||||
``libpq`` PostgreSQL client library::
|
||||
|
||||
engine = create_engine("postgresql+psycopg2://user:pass@host/dbname?client_encoding=utf8")
|
||||
|
||||
Alternatively, the above ``client_encoding`` value may be passed using
|
||||
:paramref:`_sa.create_engine.connect_args` for programmatic establishment with
|
||||
``libpq``::
|
||||
|
||||
engine = create_engine(
|
||||
"postgresql+psycopg2://user:pass@host/dbname",
|
||||
connect_args={'client_encoding': 'utf8'}
|
||||
)
|
||||
|
||||
* For all PostgreSQL versions, psycopg2 supports a client-side encoding
|
||||
value that will be passed to database connections when they are first
|
||||
established. The SQLAlchemy psycopg2 dialect supports this using the
|
||||
``client_encoding`` parameter passed to :func:`_sa.create_engine`::
|
||||
|
||||
engine = create_engine(
|
||||
"postgresql+psycopg2://user:pass@host/dbname",
|
||||
client_encoding="utf8"
|
||||
)
|
||||
|
||||
.. tip:: The above ``client_encoding`` parameter admittedly is very similar
|
||||
in appearance to usage of the parameter within the
|
||||
:paramref:`_sa.create_engine.connect_args` dictionary; the difference
|
||||
above is that the parameter is consumed by psycopg2 and is
|
||||
passed to the database connection using ``SET client_encoding TO
|
||||
'utf8'``; in the previously mentioned style, the parameter is instead
|
||||
passed through psycopg2 and consumed by the ``libpq`` library.
|
||||
|
||||
* A common way to set up client encoding with PostgreSQL databases is to
|
||||
ensure it is configured within the server-side postgresql.conf file;
|
||||
this is the recommended way to set encoding for a server that is
|
||||
consistently of one encoding in all databases::
|
||||
|
||||
# postgresql.conf file
|
||||
|
||||
# client_encoding = sql_ascii # actually, defaults to database
|
||||
# encoding
|
||||
client_encoding = utf8
|
||||
|
||||
|
||||
|
||||
Transactions
|
||||
------------
|
||||
|
||||
The psycopg2 dialect fully supports SAVEPOINT and two-phase commit operations.
|
||||
|
||||
.. _psycopg2_isolation_level:
|
||||
|
||||
Psycopg2 Transaction Isolation Level
|
||||
-------------------------------------
|
||||
|
||||
As discussed in :ref:`postgresql_isolation_level`,
|
||||
all PostgreSQL dialects support setting of transaction isolation level
|
||||
both via the ``isolation_level`` parameter passed to :func:`_sa.create_engine`
|
||||
,
|
||||
as well as the ``isolation_level`` argument used by
|
||||
:meth:`_engine.Connection.execution_options`. When using the psycopg2 dialect
|
||||
, these
|
||||
options make use of psycopg2's ``set_isolation_level()`` connection method,
|
||||
rather than emitting a PostgreSQL directive; this is because psycopg2's
|
||||
API-level setting is always emitted at the start of each transaction in any
|
||||
case.
|
||||
|
||||
The psycopg2 dialect supports these constants for isolation level:
|
||||
|
||||
* ``READ COMMITTED``
|
||||
* ``READ UNCOMMITTED``
|
||||
* ``REPEATABLE READ``
|
||||
* ``SERIALIZABLE``
|
||||
* ``AUTOCOMMIT``
|
||||
|
||||
.. seealso::
|
||||
|
||||
:ref:`postgresql_isolation_level`
|
||||
|
||||
:ref:`pg8000_isolation_level`
|
||||
|
||||
|
||||
NOTICE logging
|
||||
---------------
|
||||
|
||||
The psycopg2 dialect will log PostgreSQL NOTICE messages
|
||||
via the ``sqlalchemy.dialects.postgresql`` logger. When this logger
|
||||
is set to the ``logging.INFO`` level, notice messages will be logged::
|
||||
|
||||
import logging
|
||||
|
||||
logging.getLogger('sqlalchemy.dialects.postgresql').setLevel(logging.INFO)
|
||||
|
||||
Above, it is assumed that logging is configured externally. If this is not
|
||||
the case, configuration such as ``logging.basicConfig()`` must be utilized::
|
||||
|
||||
import logging
|
||||
|
||||
logging.basicConfig() # log messages to stdout
|
||||
logging.getLogger('sqlalchemy.dialects.postgresql').setLevel(logging.INFO)
|
||||
|
||||
.. seealso::
|
||||
|
||||
`Logging HOWTO <https://docs.python.org/3/howto/logging.html>`_ - on the python.org website
|
||||
|
||||
.. _psycopg2_hstore:
|
||||
|
||||
HSTORE type
|
||||
------------
|
||||
|
||||
The ``psycopg2`` DBAPI includes an extension to natively handle marshalling of
|
||||
the HSTORE type. The SQLAlchemy psycopg2 dialect will enable this extension
|
||||
by default when psycopg2 version 2.4 or greater is used, and
|
||||
it is detected that the target database has the HSTORE type set up for use.
|
||||
In other words, when the dialect makes the first
|
||||
connection, a sequence like the following is performed:
|
||||
|
||||
1. Request the available HSTORE oids using
|
||||
``psycopg2.extras.HstoreAdapter.get_oids()``.
|
||||
If this function returns a list of HSTORE identifiers, we then determine
|
||||
that the ``HSTORE`` extension is present.
|
||||
This function is **skipped** if the version of psycopg2 installed is
|
||||
less than version 2.4.
|
||||
|
||||
2. If the ``use_native_hstore`` flag is at its default of ``True``, and
|
||||
we've detected that ``HSTORE`` oids are available, the
|
||||
``psycopg2.extensions.register_hstore()`` extension is invoked for all
|
||||
connections.
|
||||
|
||||
The ``register_hstore()`` extension has the effect of **all Python
|
||||
dictionaries being accepted as parameters regardless of the type of target
|
||||
column in SQL**. The dictionaries are converted by this extension into a
|
||||
textual HSTORE expression. If this behavior is not desired, disable the
|
||||
use of the hstore extension by setting ``use_native_hstore`` to ``False`` as
|
||||
follows::
|
||||
|
||||
engine = create_engine("postgresql+psycopg2://scott:tiger@localhost/test",
|
||||
use_native_hstore=False)
|
||||
|
||||
The ``HSTORE`` type is **still supported** when the
|
||||
``psycopg2.extensions.register_hstore()`` extension is not used. It merely
|
||||
means that the coercion between Python dictionaries and the HSTORE
|
||||
string format, on both the parameter side and the result side, will take
|
||||
place within SQLAlchemy's own marshalling logic, and not that of ``psycopg2``
|
||||
which may be more performant.
|
||||
|
||||
""" # noqa
|
||||
from __future__ import annotations
|
||||
|
||||
import collections.abc as collections_abc
|
||||
import logging
|
||||
import re
|
||||
from typing import cast
|
||||
|
||||
from . import ranges
|
||||
from ._psycopg_common import _PGDialect_common_psycopg
|
||||
from ._psycopg_common import _PGExecutionContext_common_psycopg
|
||||
from .base import PGIdentifierPreparer
|
||||
from .json import JSON
|
||||
from .json import JSONB
|
||||
from ... import types as sqltypes
|
||||
from ... import util
|
||||
from ...util import FastIntFlag
|
||||
from ...util import parse_user_argument_for_enum
|
||||
|
||||
logger = logging.getLogger("sqlalchemy.dialects.postgresql")
|
||||
|
||||
|
||||
class _PGJSON(JSON):
|
||||
def result_processor(self, dialect, coltype):
|
||||
return None
|
||||
|
||||
|
||||
class _PGJSONB(JSONB):
|
||||
def result_processor(self, dialect, coltype):
|
||||
return None
|
||||
|
||||
|
||||
class _Psycopg2Range(ranges.AbstractSingleRangeImpl):
|
||||
_psycopg2_range_cls = "none"
|
||||
|
||||
def bind_processor(self, dialect):
|
||||
psycopg2_Range = getattr(
|
||||
cast(PGDialect_psycopg2, dialect)._psycopg2_extras,
|
||||
self._psycopg2_range_cls,
|
||||
)
|
||||
|
||||
def to_range(value):
|
||||
if isinstance(value, ranges.Range):
|
||||
value = psycopg2_Range(
|
||||
value.lower, value.upper, value.bounds, value.empty
|
||||
)
|
||||
return value
|
||||
|
||||
return to_range
|
||||
|
||||
def result_processor(self, dialect, coltype):
|
||||
def to_range(value):
|
||||
if value is not None:
|
||||
value = ranges.Range(
|
||||
value._lower,
|
||||
value._upper,
|
||||
bounds=value._bounds if value._bounds else "[)",
|
||||
empty=not value._bounds,
|
||||
)
|
||||
return value
|
||||
|
||||
return to_range
|
||||
|
||||
|
||||
class _Psycopg2NumericRange(_Psycopg2Range):
|
||||
_psycopg2_range_cls = "NumericRange"
|
||||
|
||||
|
||||
class _Psycopg2DateRange(_Psycopg2Range):
|
||||
_psycopg2_range_cls = "DateRange"
|
||||
|
||||
|
||||
class _Psycopg2DateTimeRange(_Psycopg2Range):
|
||||
_psycopg2_range_cls = "DateTimeRange"
|
||||
|
||||
|
||||
class _Psycopg2DateTimeTZRange(_Psycopg2Range):
|
||||
_psycopg2_range_cls = "DateTimeTZRange"
|
||||
|
||||
|
||||
class PGExecutionContext_psycopg2(_PGExecutionContext_common_psycopg):
|
||||
_psycopg2_fetched_rows = None
|
||||
|
||||
def post_exec(self):
|
||||
self._log_notices(self.cursor)
|
||||
|
||||
def _log_notices(self, cursor):
|
||||
# check also that notices is an iterable, after it's already
|
||||
# established that we will be iterating through it. This is to get
|
||||
# around test suites such as SQLAlchemy's using a Mock object for
|
||||
# cursor
|
||||
if not cursor.connection.notices or not isinstance(
|
||||
cursor.connection.notices, collections_abc.Iterable
|
||||
):
|
||||
return
|
||||
|
||||
for notice in cursor.connection.notices:
|
||||
# NOTICE messages have a
|
||||
# newline character at the end
|
||||
logger.info(notice.rstrip())
|
||||
|
||||
cursor.connection.notices[:] = []
|
||||
|
||||
|
||||
class PGIdentifierPreparer_psycopg2(PGIdentifierPreparer):
|
||||
pass
|
||||
|
||||
|
||||
class ExecutemanyMode(FastIntFlag):
|
||||
EXECUTEMANY_VALUES = 0
|
||||
EXECUTEMANY_VALUES_PLUS_BATCH = 1
|
||||
|
||||
|
||||
(
|
||||
EXECUTEMANY_VALUES,
|
||||
EXECUTEMANY_VALUES_PLUS_BATCH,
|
||||
) = ExecutemanyMode.__members__.values()
|
||||
|
||||
|
||||
class PGDialect_psycopg2(_PGDialect_common_psycopg):
|
||||
driver = "psycopg2"
|
||||
|
||||
supports_statement_cache = True
|
||||
supports_server_side_cursors = True
|
||||
|
||||
default_paramstyle = "pyformat"
|
||||
# set to true based on psycopg2 version
|
||||
supports_sane_multi_rowcount = False
|
||||
execution_ctx_cls = PGExecutionContext_psycopg2
|
||||
preparer = PGIdentifierPreparer_psycopg2
|
||||
psycopg2_version = (0, 0)
|
||||
use_insertmanyvalues_wo_returning = True
|
||||
|
||||
returns_native_bytes = False
|
||||
|
||||
_has_native_hstore = True
|
||||
|
||||
colspecs = util.update_copy(
|
||||
_PGDialect_common_psycopg.colspecs,
|
||||
{
|
||||
JSON: _PGJSON,
|
||||
sqltypes.JSON: _PGJSON,
|
||||
JSONB: _PGJSONB,
|
||||
ranges.INT4RANGE: _Psycopg2NumericRange,
|
||||
ranges.INT8RANGE: _Psycopg2NumericRange,
|
||||
ranges.NUMRANGE: _Psycopg2NumericRange,
|
||||
ranges.DATERANGE: _Psycopg2DateRange,
|
||||
ranges.TSRANGE: _Psycopg2DateTimeRange,
|
||||
ranges.TSTZRANGE: _Psycopg2DateTimeTZRange,
|
||||
},
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
executemany_mode="values_only",
|
||||
executemany_batch_page_size=100,
|
||||
**kwargs,
|
||||
):
|
||||
_PGDialect_common_psycopg.__init__(self, **kwargs)
|
||||
|
||||
if self._native_inet_types:
|
||||
raise NotImplementedError(
|
||||
"The psycopg2 dialect does not implement "
|
||||
"ipaddress type handling; native_inet_types cannot be set "
|
||||
"to ``True`` when using this dialect."
|
||||
)
|
||||
|
||||
# Parse executemany_mode argument, allowing it to be only one of the
|
||||
# symbol names
|
||||
self.executemany_mode = parse_user_argument_for_enum(
|
||||
executemany_mode,
|
||||
{
|
||||
EXECUTEMANY_VALUES: ["values_only"],
|
||||
EXECUTEMANY_VALUES_PLUS_BATCH: ["values_plus_batch"],
|
||||
},
|
||||
"executemany_mode",
|
||||
)
|
||||
|
||||
self.executemany_batch_page_size = executemany_batch_page_size
|
||||
|
||||
if self.dbapi and hasattr(self.dbapi, "__version__"):
|
||||
m = re.match(r"(\d+)\.(\d+)(?:\.(\d+))?", self.dbapi.__version__)
|
||||
if m:
|
||||
self.psycopg2_version = tuple(
|
||||
int(x) for x in m.group(1, 2, 3) if x is not None
|
||||
)
|
||||
|
||||
if self.psycopg2_version < (2, 7):
|
||||
raise ImportError(
|
||||
"psycopg2 version 2.7 or higher is required."
|
||||
)
|
||||
|
||||
def initialize(self, connection):
|
||||
super().initialize(connection)
|
||||
self._has_native_hstore = (
|
||||
self.use_native_hstore
|
||||
and self._hstore_oids(connection.connection.dbapi_connection)
|
||||
is not None
|
||||
)
|
||||
|
||||
self.supports_sane_multi_rowcount = (
|
||||
self.executemany_mode is not EXECUTEMANY_VALUES_PLUS_BATCH
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def import_dbapi(cls):
|
||||
import psycopg2
|
||||
|
||||
return psycopg2
|
||||
|
||||
@util.memoized_property
|
||||
def _psycopg2_extensions(cls):
|
||||
from psycopg2 import extensions
|
||||
|
||||
return extensions
|
||||
|
||||
@util.memoized_property
|
||||
def _psycopg2_extras(cls):
|
||||
from psycopg2 import extras
|
||||
|
||||
return extras
|
||||
|
||||
@util.memoized_property
|
||||
def _isolation_lookup(self):
|
||||
extensions = self._psycopg2_extensions
|
||||
return {
|
||||
"AUTOCOMMIT": extensions.ISOLATION_LEVEL_AUTOCOMMIT,
|
||||
"READ COMMITTED": extensions.ISOLATION_LEVEL_READ_COMMITTED,
|
||||
"READ UNCOMMITTED": extensions.ISOLATION_LEVEL_READ_UNCOMMITTED,
|
||||
"REPEATABLE READ": extensions.ISOLATION_LEVEL_REPEATABLE_READ,
|
||||
"SERIALIZABLE": extensions.ISOLATION_LEVEL_SERIALIZABLE,
|
||||
}
|
||||
|
||||
def set_isolation_level(self, dbapi_connection, level):
|
||||
dbapi_connection.set_isolation_level(self._isolation_lookup[level])
|
||||
|
||||
def set_readonly(self, connection, value):
|
||||
connection.readonly = value
|
||||
|
||||
def get_readonly(self, connection):
|
||||
return connection.readonly
|
||||
|
||||
def set_deferrable(self, connection, value):
|
||||
connection.deferrable = value
|
||||
|
||||
def get_deferrable(self, connection):
|
||||
return connection.deferrable
|
||||
|
||||
def on_connect(self):
|
||||
extras = self._psycopg2_extras
|
||||
|
||||
fns = []
|
||||
if self.client_encoding is not None:
|
||||
|
||||
def on_connect(dbapi_conn):
|
||||
dbapi_conn.set_client_encoding(self.client_encoding)
|
||||
|
||||
fns.append(on_connect)
|
||||
|
||||
if self.dbapi:
|
||||
|
||||
def on_connect(dbapi_conn):
|
||||
extras.register_uuid(None, dbapi_conn)
|
||||
|
||||
fns.append(on_connect)
|
||||
|
||||
if self.dbapi and self.use_native_hstore:
|
||||
|
||||
def on_connect(dbapi_conn):
|
||||
hstore_oids = self._hstore_oids(dbapi_conn)
|
||||
if hstore_oids is not None:
|
||||
oid, array_oid = hstore_oids
|
||||
kw = {"oid": oid}
|
||||
kw["array_oid"] = array_oid
|
||||
extras.register_hstore(dbapi_conn, **kw)
|
||||
|
||||
fns.append(on_connect)
|
||||
|
||||
if self.dbapi and self._json_deserializer:
|
||||
|
||||
def on_connect(dbapi_conn):
|
||||
extras.register_default_json(
|
||||
dbapi_conn, loads=self._json_deserializer
|
||||
)
|
||||
extras.register_default_jsonb(
|
||||
dbapi_conn, loads=self._json_deserializer
|
||||
)
|
||||
|
||||
fns.append(on_connect)
|
||||
|
||||
if fns:
|
||||
|
||||
def on_connect(dbapi_conn):
|
||||
for fn in fns:
|
||||
fn(dbapi_conn)
|
||||
|
||||
return on_connect
|
||||
else:
|
||||
return None
|
||||
|
||||
def do_executemany(self, cursor, statement, parameters, context=None):
|
||||
if self.executemany_mode is EXECUTEMANY_VALUES_PLUS_BATCH:
|
||||
if self.executemany_batch_page_size:
|
||||
kwargs = {"page_size": self.executemany_batch_page_size}
|
||||
else:
|
||||
kwargs = {}
|
||||
self._psycopg2_extras.execute_batch(
|
||||
cursor, statement, parameters, **kwargs
|
||||
)
|
||||
else:
|
||||
cursor.executemany(statement, parameters)
|
||||
|
||||
def do_begin_twophase(self, connection, xid):
|
||||
connection.connection.tpc_begin(xid)
|
||||
|
||||
def do_prepare_twophase(self, connection, xid):
|
||||
connection.connection.tpc_prepare()
|
||||
|
||||
def _do_twophase(self, dbapi_conn, operation, xid, recover=False):
|
||||
if recover:
|
||||
if dbapi_conn.status != self._psycopg2_extensions.STATUS_READY:
|
||||
dbapi_conn.rollback()
|
||||
operation(xid)
|
||||
else:
|
||||
operation()
|
||||
|
||||
def do_rollback_twophase(
|
||||
self, connection, xid, is_prepared=True, recover=False
|
||||
):
|
||||
dbapi_conn = connection.connection.dbapi_connection
|
||||
self._do_twophase(
|
||||
dbapi_conn, dbapi_conn.tpc_rollback, xid, recover=recover
|
||||
)
|
||||
|
||||
def do_commit_twophase(
|
||||
self, connection, xid, is_prepared=True, recover=False
|
||||
):
|
||||
dbapi_conn = connection.connection.dbapi_connection
|
||||
self._do_twophase(
|
||||
dbapi_conn, dbapi_conn.tpc_commit, xid, recover=recover
|
||||
)
|
||||
|
||||
@util.memoized_instancemethod
|
||||
def _hstore_oids(self, dbapi_connection):
|
||||
extras = self._psycopg2_extras
|
||||
oids = extras.HstoreAdapter.get_oids(dbapi_connection)
|
||||
if oids is not None and oids[0]:
|
||||
return oids[0:2]
|
||||
else:
|
||||
return None
|
||||
|
||||
def is_disconnect(self, e, connection, cursor):
|
||||
if isinstance(e, self.dbapi.Error):
|
||||
# check the "closed" flag. this might not be
|
||||
# present on old psycopg2 versions. Also,
|
||||
# this flag doesn't actually help in a lot of disconnect
|
||||
# situations, so don't rely on it.
|
||||
if getattr(connection, "closed", False):
|
||||
return True
|
||||
|
||||
# checks based on strings. in the case that .closed
|
||||
# didn't cut it, fall back onto these.
|
||||
str_e = str(e).partition("\n")[0]
|
||||
for msg in [
|
||||
# these error messages from libpq: interfaces/libpq/fe-misc.c
|
||||
# and interfaces/libpq/fe-secure.c.
|
||||
"terminating connection",
|
||||
"closed the connection",
|
||||
"connection not open",
|
||||
"could not receive data from server",
|
||||
"could not send data to server",
|
||||
# psycopg2 client errors, psycopg2/connection.h,
|
||||
# psycopg2/cursor.h
|
||||
"connection already closed",
|
||||
"cursor already closed",
|
||||
# not sure where this path is originally from, it may
|
||||
# be obsolete. It really says "losed", not "closed".
|
||||
"losed the connection unexpectedly",
|
||||
# these can occur in newer SSL
|
||||
"connection has been closed unexpectedly",
|
||||
"SSL error: decryption failed or bad record mac",
|
||||
"SSL SYSCALL error: Bad file descriptor",
|
||||
"SSL SYSCALL error: EOF detected",
|
||||
"SSL SYSCALL error: Operation timed out",
|
||||
"SSL SYSCALL error: Bad address",
|
||||
]:
|
||||
idx = str_e.find(msg)
|
||||
if idx >= 0 and '"' not in str_e[:idx]:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
dialect = PGDialect_psycopg2
|
|
@ -0,0 +1,61 @@
|
|||
# dialects/postgresql/psycopg2cffi.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
r"""
|
||||
.. dialect:: postgresql+psycopg2cffi
|
||||
:name: psycopg2cffi
|
||||
:dbapi: psycopg2cffi
|
||||
:connectstring: postgresql+psycopg2cffi://user:password@host:port/dbname[?key=value&key=value...]
|
||||
:url: https://pypi.org/project/psycopg2cffi/
|
||||
|
||||
``psycopg2cffi`` is an adaptation of ``psycopg2``, using CFFI for the C
|
||||
layer. This makes it suitable for use in e.g. PyPy. Documentation
|
||||
is as per ``psycopg2``.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:mod:`sqlalchemy.dialects.postgresql.psycopg2`
|
||||
|
||||
""" # noqa
|
||||
from .psycopg2 import PGDialect_psycopg2
|
||||
from ... import util
|
||||
|
||||
|
||||
class PGDialect_psycopg2cffi(PGDialect_psycopg2):
|
||||
driver = "psycopg2cffi"
|
||||
supports_unicode_statements = True
|
||||
supports_statement_cache = True
|
||||
|
||||
# psycopg2cffi's first release is 2.5.0, but reports
|
||||
# __version__ as 2.4.4. Subsequent releases seem to have
|
||||
# fixed this.
|
||||
|
||||
FEATURE_VERSION_MAP = dict(
|
||||
native_json=(2, 4, 4),
|
||||
native_jsonb=(2, 7, 1),
|
||||
sane_multi_rowcount=(2, 4, 4),
|
||||
array_oid=(2, 4, 4),
|
||||
hstore_adapter=(2, 4, 4),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def import_dbapi(cls):
|
||||
return __import__("psycopg2cffi")
|
||||
|
||||
@util.memoized_property
|
||||
def _psycopg2_extensions(cls):
|
||||
root = __import__("psycopg2cffi", fromlist=["extensions"])
|
||||
return root.extensions
|
||||
|
||||
@util.memoized_property
|
||||
def _psycopg2_extras(cls):
|
||||
root = __import__("psycopg2cffi", fromlist=["extras"])
|
||||
return root.extras
|
||||
|
||||
|
||||
dialect = PGDialect_psycopg2cffi
|
File diff suppressed because it is too large
Load diff
|
@ -0,0 +1,303 @@
|
|||
# dialects/postgresql/types.py
|
||||
# Copyright (C) 2013-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime as dt
|
||||
from typing import Any
|
||||
from typing import Optional
|
||||
from typing import overload
|
||||
from typing import Type
|
||||
from typing import TYPE_CHECKING
|
||||
from uuid import UUID as _python_UUID
|
||||
|
||||
from ...sql import sqltypes
|
||||
from ...sql import type_api
|
||||
from ...util.typing import Literal
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ...engine.interfaces import Dialect
|
||||
from ...sql.operators import OperatorType
|
||||
from ...sql.type_api import _LiteralProcessorType
|
||||
from ...sql.type_api import TypeEngine
|
||||
|
||||
_DECIMAL_TYPES = (1231, 1700)
|
||||
_FLOAT_TYPES = (700, 701, 1021, 1022)
|
||||
_INT_TYPES = (20, 21, 23, 26, 1005, 1007, 1016)
|
||||
|
||||
|
||||
class PGUuid(sqltypes.UUID[sqltypes._UUID_RETURN]):
|
||||
render_bind_cast = True
|
||||
render_literal_cast = True
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self: PGUuid[_python_UUID], as_uuid: Literal[True] = ...
|
||||
) -> None: ...
|
||||
|
||||
@overload
|
||||
def __init__(
|
||||
self: PGUuid[str], as_uuid: Literal[False] = ...
|
||||
) -> None: ...
|
||||
|
||||
def __init__(self, as_uuid: bool = True) -> None: ...
|
||||
|
||||
|
||||
class BYTEA(sqltypes.LargeBinary):
|
||||
__visit_name__ = "BYTEA"
|
||||
|
||||
|
||||
class INET(sqltypes.TypeEngine[str]):
|
||||
__visit_name__ = "INET"
|
||||
|
||||
|
||||
PGInet = INET
|
||||
|
||||
|
||||
class CIDR(sqltypes.TypeEngine[str]):
|
||||
__visit_name__ = "CIDR"
|
||||
|
||||
|
||||
PGCidr = CIDR
|
||||
|
||||
|
||||
class MACADDR(sqltypes.TypeEngine[str]):
|
||||
__visit_name__ = "MACADDR"
|
||||
|
||||
|
||||
PGMacAddr = MACADDR
|
||||
|
||||
|
||||
class MACADDR8(sqltypes.TypeEngine[str]):
|
||||
__visit_name__ = "MACADDR8"
|
||||
|
||||
|
||||
PGMacAddr8 = MACADDR8
|
||||
|
||||
|
||||
class MONEY(sqltypes.TypeEngine[str]):
|
||||
r"""Provide the PostgreSQL MONEY type.
|
||||
|
||||
Depending on driver, result rows using this type may return a
|
||||
string value which includes currency symbols.
|
||||
|
||||
For this reason, it may be preferable to provide conversion to a
|
||||
numerically-based currency datatype using :class:`_types.TypeDecorator`::
|
||||
|
||||
import re
|
||||
import decimal
|
||||
from sqlalchemy import Dialect
|
||||
from sqlalchemy import TypeDecorator
|
||||
|
||||
class NumericMoney(TypeDecorator):
|
||||
impl = MONEY
|
||||
|
||||
def process_result_value(
|
||||
self, value: Any, dialect: Dialect
|
||||
) -> None:
|
||||
if value is not None:
|
||||
# adjust this for the currency and numeric
|
||||
m = re.match(r"\$([\d.]+)", value)
|
||||
if m:
|
||||
value = decimal.Decimal(m.group(1))
|
||||
return value
|
||||
|
||||
Alternatively, the conversion may be applied as a CAST using
|
||||
the :meth:`_types.TypeDecorator.column_expression` method as follows::
|
||||
|
||||
import decimal
|
||||
from sqlalchemy import cast
|
||||
from sqlalchemy import TypeDecorator
|
||||
|
||||
class NumericMoney(TypeDecorator):
|
||||
impl = MONEY
|
||||
|
||||
def column_expression(self, column: Any):
|
||||
return cast(column, Numeric())
|
||||
|
||||
.. versionadded:: 1.2
|
||||
|
||||
"""
|
||||
|
||||
__visit_name__ = "MONEY"
|
||||
|
||||
|
||||
class OID(sqltypes.TypeEngine[int]):
|
||||
"""Provide the PostgreSQL OID type."""
|
||||
|
||||
__visit_name__ = "OID"
|
||||
|
||||
|
||||
class REGCONFIG(sqltypes.TypeEngine[str]):
|
||||
"""Provide the PostgreSQL REGCONFIG type.
|
||||
|
||||
.. versionadded:: 2.0.0rc1
|
||||
|
||||
"""
|
||||
|
||||
__visit_name__ = "REGCONFIG"
|
||||
|
||||
|
||||
class TSQUERY(sqltypes.TypeEngine[str]):
|
||||
"""Provide the PostgreSQL TSQUERY type.
|
||||
|
||||
.. versionadded:: 2.0.0rc1
|
||||
|
||||
"""
|
||||
|
||||
__visit_name__ = "TSQUERY"
|
||||
|
||||
|
||||
class REGCLASS(sqltypes.TypeEngine[str]):
|
||||
"""Provide the PostgreSQL REGCLASS type.
|
||||
|
||||
.. versionadded:: 1.2.7
|
||||
|
||||
"""
|
||||
|
||||
__visit_name__ = "REGCLASS"
|
||||
|
||||
|
||||
class TIMESTAMP(sqltypes.TIMESTAMP):
|
||||
"""Provide the PostgreSQL TIMESTAMP type."""
|
||||
|
||||
__visit_name__ = "TIMESTAMP"
|
||||
|
||||
def __init__(
|
||||
self, timezone: bool = False, precision: Optional[int] = None
|
||||
) -> None:
|
||||
"""Construct a TIMESTAMP.
|
||||
|
||||
:param timezone: boolean value if timezone present, default False
|
||||
:param precision: optional integer precision value
|
||||
|
||||
.. versionadded:: 1.4
|
||||
|
||||
"""
|
||||
super().__init__(timezone=timezone)
|
||||
self.precision = precision
|
||||
|
||||
|
||||
class TIME(sqltypes.TIME):
|
||||
"""PostgreSQL TIME type."""
|
||||
|
||||
__visit_name__ = "TIME"
|
||||
|
||||
def __init__(
|
||||
self, timezone: bool = False, precision: Optional[int] = None
|
||||
) -> None:
|
||||
"""Construct a TIME.
|
||||
|
||||
:param timezone: boolean value if timezone present, default False
|
||||
:param precision: optional integer precision value
|
||||
|
||||
.. versionadded:: 1.4
|
||||
|
||||
"""
|
||||
super().__init__(timezone=timezone)
|
||||
self.precision = precision
|
||||
|
||||
|
||||
class INTERVAL(type_api.NativeForEmulated, sqltypes._AbstractInterval):
|
||||
"""PostgreSQL INTERVAL type."""
|
||||
|
||||
__visit_name__ = "INTERVAL"
|
||||
native = True
|
||||
|
||||
def __init__(
|
||||
self, precision: Optional[int] = None, fields: Optional[str] = None
|
||||
) -> None:
|
||||
"""Construct an INTERVAL.
|
||||
|
||||
:param precision: optional integer precision value
|
||||
:param fields: string fields specifier. allows storage of fields
|
||||
to be limited, such as ``"YEAR"``, ``"MONTH"``, ``"DAY TO HOUR"``,
|
||||
etc.
|
||||
|
||||
.. versionadded:: 1.2
|
||||
|
||||
"""
|
||||
self.precision = precision
|
||||
self.fields = fields
|
||||
|
||||
@classmethod
|
||||
def adapt_emulated_to_native(
|
||||
cls, interval: sqltypes.Interval, **kw: Any # type: ignore[override]
|
||||
) -> INTERVAL:
|
||||
return INTERVAL(precision=interval.second_precision)
|
||||
|
||||
@property
|
||||
def _type_affinity(self) -> Type[sqltypes.Interval]:
|
||||
return sqltypes.Interval
|
||||
|
||||
def as_generic(self, allow_nulltype: bool = False) -> sqltypes.Interval:
|
||||
return sqltypes.Interval(native=True, second_precision=self.precision)
|
||||
|
||||
@property
|
||||
def python_type(self) -> Type[dt.timedelta]:
|
||||
return dt.timedelta
|
||||
|
||||
def literal_processor(
|
||||
self, dialect: Dialect
|
||||
) -> Optional[_LiteralProcessorType[dt.timedelta]]:
|
||||
def process(value: dt.timedelta) -> str:
|
||||
return f"make_interval(secs=>{value.total_seconds()})"
|
||||
|
||||
return process
|
||||
|
||||
|
||||
PGInterval = INTERVAL
|
||||
|
||||
|
||||
class BIT(sqltypes.TypeEngine[int]):
|
||||
__visit_name__ = "BIT"
|
||||
|
||||
def __init__(
|
||||
self, length: Optional[int] = None, varying: bool = False
|
||||
) -> None:
|
||||
if varying:
|
||||
# BIT VARYING can be unlimited-length, so no default
|
||||
self.length = length
|
||||
else:
|
||||
# BIT without VARYING defaults to length 1
|
||||
self.length = length or 1
|
||||
self.varying = varying
|
||||
|
||||
|
||||
PGBit = BIT
|
||||
|
||||
|
||||
class TSVECTOR(sqltypes.TypeEngine[str]):
|
||||
"""The :class:`_postgresql.TSVECTOR` type implements the PostgreSQL
|
||||
text search type TSVECTOR.
|
||||
|
||||
It can be used to do full text queries on natural language
|
||||
documents.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:ref:`postgresql_match`
|
||||
|
||||
"""
|
||||
|
||||
__visit_name__ = "TSVECTOR"
|
||||
|
||||
|
||||
class CITEXT(sqltypes.TEXT):
|
||||
"""Provide the PostgreSQL CITEXT type.
|
||||
|
||||
.. versionadded:: 2.0.7
|
||||
|
||||
"""
|
||||
|
||||
__visit_name__ = "CITEXT"
|
||||
|
||||
def coerce_compared_value(
|
||||
self, op: Optional[OperatorType], value: Any
|
||||
) -> TypeEngine[Any]:
|
||||
return self
|
|
@ -0,0 +1,57 @@
|
|||
# dialects/sqlite/__init__.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
from . import aiosqlite # noqa
|
||||
from . import base # noqa
|
||||
from . import pysqlcipher # noqa
|
||||
from . import pysqlite # noqa
|
||||
from .base import BLOB
|
||||
from .base import BOOLEAN
|
||||
from .base import CHAR
|
||||
from .base import DATE
|
||||
from .base import DATETIME
|
||||
from .base import DECIMAL
|
||||
from .base import FLOAT
|
||||
from .base import INTEGER
|
||||
from .base import JSON
|
||||
from .base import NUMERIC
|
||||
from .base import REAL
|
||||
from .base import SMALLINT
|
||||
from .base import TEXT
|
||||
from .base import TIME
|
||||
from .base import TIMESTAMP
|
||||
from .base import VARCHAR
|
||||
from .dml import Insert
|
||||
from .dml import insert
|
||||
|
||||
# default dialect
|
||||
base.dialect = dialect = pysqlite.dialect
|
||||
|
||||
|
||||
__all__ = (
|
||||
"BLOB",
|
||||
"BOOLEAN",
|
||||
"CHAR",
|
||||
"DATE",
|
||||
"DATETIME",
|
||||
"DECIMAL",
|
||||
"FLOAT",
|
||||
"INTEGER",
|
||||
"JSON",
|
||||
"NUMERIC",
|
||||
"SMALLINT",
|
||||
"TEXT",
|
||||
"TIME",
|
||||
"TIMESTAMP",
|
||||
"VARCHAR",
|
||||
"REAL",
|
||||
"Insert",
|
||||
"insert",
|
||||
"dialect",
|
||||
)
|
|
@ -0,0 +1,396 @@
|
|||
# dialects/sqlite/aiosqlite.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
r"""
|
||||
|
||||
.. dialect:: sqlite+aiosqlite
|
||||
:name: aiosqlite
|
||||
:dbapi: aiosqlite
|
||||
:connectstring: sqlite+aiosqlite:///file_path
|
||||
:url: https://pypi.org/project/aiosqlite/
|
||||
|
||||
The aiosqlite dialect provides support for the SQLAlchemy asyncio interface
|
||||
running on top of pysqlite.
|
||||
|
||||
aiosqlite is a wrapper around pysqlite that uses a background thread for
|
||||
each connection. It does not actually use non-blocking IO, as SQLite
|
||||
databases are not socket-based. However it does provide a working asyncio
|
||||
interface that's useful for testing and prototyping purposes.
|
||||
|
||||
Using a special asyncio mediation layer, the aiosqlite dialect is usable
|
||||
as the backend for the :ref:`SQLAlchemy asyncio <asyncio_toplevel>`
|
||||
extension package.
|
||||
|
||||
This dialect should normally be used only with the
|
||||
:func:`_asyncio.create_async_engine` engine creation function::
|
||||
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
engine = create_async_engine("sqlite+aiosqlite:///filename")
|
||||
|
||||
The URL passes through all arguments to the ``pysqlite`` driver, so all
|
||||
connection arguments are the same as they are for that of :ref:`pysqlite`.
|
||||
|
||||
.. _aiosqlite_udfs:
|
||||
|
||||
User-Defined Functions
|
||||
----------------------
|
||||
|
||||
aiosqlite extends pysqlite to support async, so we can create our own user-defined functions (UDFs)
|
||||
in Python and use them directly in SQLite queries as described here: :ref:`pysqlite_udfs`.
|
||||
|
||||
.. _aiosqlite_serializable:
|
||||
|
||||
Serializable isolation / Savepoints / Transactional DDL (asyncio version)
|
||||
-------------------------------------------------------------------------
|
||||
|
||||
Similarly to pysqlite, aiosqlite does not support SAVEPOINT feature.
|
||||
|
||||
The solution is similar to :ref:`pysqlite_serializable`. This is achieved by the event listeners in async::
|
||||
|
||||
from sqlalchemy import create_engine, event
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
||||
engine = create_async_engine("sqlite+aiosqlite:///myfile.db")
|
||||
|
||||
@event.listens_for(engine.sync_engine, "connect")
|
||||
def do_connect(dbapi_connection, connection_record):
|
||||
# disable aiosqlite's emitting of the BEGIN statement entirely.
|
||||
# also stops it from emitting COMMIT before any DDL.
|
||||
dbapi_connection.isolation_level = None
|
||||
|
||||
@event.listens_for(engine.sync_engine, "begin")
|
||||
def do_begin(conn):
|
||||
# emit our own BEGIN
|
||||
conn.exec_driver_sql("BEGIN")
|
||||
|
||||
.. warning:: When using the above recipe, it is advised to not use the
|
||||
:paramref:`.Connection.execution_options.isolation_level` setting on
|
||||
:class:`_engine.Connection` and :func:`_sa.create_engine`
|
||||
with the SQLite driver,
|
||||
as this function necessarily will also alter the ".isolation_level" setting.
|
||||
|
||||
""" # noqa
|
||||
|
||||
import asyncio
|
||||
from functools import partial
|
||||
|
||||
from .base import SQLiteExecutionContext
|
||||
from .pysqlite import SQLiteDialect_pysqlite
|
||||
from ... import pool
|
||||
from ... import util
|
||||
from ...engine import AdaptedConnection
|
||||
from ...util.concurrency import await_fallback
|
||||
from ...util.concurrency import await_only
|
||||
|
||||
|
||||
class AsyncAdapt_aiosqlite_cursor:
|
||||
# TODO: base on connectors/asyncio.py
|
||||
# see #10415
|
||||
|
||||
__slots__ = (
|
||||
"_adapt_connection",
|
||||
"_connection",
|
||||
"description",
|
||||
"await_",
|
||||
"_rows",
|
||||
"arraysize",
|
||||
"rowcount",
|
||||
"lastrowid",
|
||||
)
|
||||
|
||||
server_side = False
|
||||
|
||||
def __init__(self, adapt_connection):
|
||||
self._adapt_connection = adapt_connection
|
||||
self._connection = adapt_connection._connection
|
||||
self.await_ = adapt_connection.await_
|
||||
self.arraysize = 1
|
||||
self.rowcount = -1
|
||||
self.description = None
|
||||
self._rows = []
|
||||
|
||||
def close(self):
|
||||
self._rows[:] = []
|
||||
|
||||
def execute(self, operation, parameters=None):
|
||||
try:
|
||||
_cursor = self.await_(self._connection.cursor())
|
||||
|
||||
if parameters is None:
|
||||
self.await_(_cursor.execute(operation))
|
||||
else:
|
||||
self.await_(_cursor.execute(operation, parameters))
|
||||
|
||||
if _cursor.description:
|
||||
self.description = _cursor.description
|
||||
self.lastrowid = self.rowcount = -1
|
||||
|
||||
if not self.server_side:
|
||||
self._rows = self.await_(_cursor.fetchall())
|
||||
else:
|
||||
self.description = None
|
||||
self.lastrowid = _cursor.lastrowid
|
||||
self.rowcount = _cursor.rowcount
|
||||
|
||||
if not self.server_side:
|
||||
self.await_(_cursor.close())
|
||||
else:
|
||||
self._cursor = _cursor
|
||||
except Exception as error:
|
||||
self._adapt_connection._handle_exception(error)
|
||||
|
||||
def executemany(self, operation, seq_of_parameters):
|
||||
try:
|
||||
_cursor = self.await_(self._connection.cursor())
|
||||
self.await_(_cursor.executemany(operation, seq_of_parameters))
|
||||
self.description = None
|
||||
self.lastrowid = _cursor.lastrowid
|
||||
self.rowcount = _cursor.rowcount
|
||||
self.await_(_cursor.close())
|
||||
except Exception as error:
|
||||
self._adapt_connection._handle_exception(error)
|
||||
|
||||
def setinputsizes(self, *inputsizes):
|
||||
pass
|
||||
|
||||
def __iter__(self):
|
||||
while self._rows:
|
||||
yield self._rows.pop(0)
|
||||
|
||||
def fetchone(self):
|
||||
if self._rows:
|
||||
return self._rows.pop(0)
|
||||
else:
|
||||
return None
|
||||
|
||||
def fetchmany(self, size=None):
|
||||
if size is None:
|
||||
size = self.arraysize
|
||||
|
||||
retval = self._rows[0:size]
|
||||
self._rows[:] = self._rows[size:]
|
||||
return retval
|
||||
|
||||
def fetchall(self):
|
||||
retval = self._rows[:]
|
||||
self._rows[:] = []
|
||||
return retval
|
||||
|
||||
|
||||
class AsyncAdapt_aiosqlite_ss_cursor(AsyncAdapt_aiosqlite_cursor):
|
||||
# TODO: base on connectors/asyncio.py
|
||||
# see #10415
|
||||
__slots__ = "_cursor"
|
||||
|
||||
server_side = True
|
||||
|
||||
def __init__(self, *arg, **kw):
|
||||
super().__init__(*arg, **kw)
|
||||
self._cursor = None
|
||||
|
||||
def close(self):
|
||||
if self._cursor is not None:
|
||||
self.await_(self._cursor.close())
|
||||
self._cursor = None
|
||||
|
||||
def fetchone(self):
|
||||
return self.await_(self._cursor.fetchone())
|
||||
|
||||
def fetchmany(self, size=None):
|
||||
if size is None:
|
||||
size = self.arraysize
|
||||
return self.await_(self._cursor.fetchmany(size=size))
|
||||
|
||||
def fetchall(self):
|
||||
return self.await_(self._cursor.fetchall())
|
||||
|
||||
|
||||
class AsyncAdapt_aiosqlite_connection(AdaptedConnection):
|
||||
await_ = staticmethod(await_only)
|
||||
__slots__ = ("dbapi",)
|
||||
|
||||
def __init__(self, dbapi, connection):
|
||||
self.dbapi = dbapi
|
||||
self._connection = connection
|
||||
|
||||
@property
|
||||
def isolation_level(self):
|
||||
return self._connection.isolation_level
|
||||
|
||||
@isolation_level.setter
|
||||
def isolation_level(self, value):
|
||||
# aiosqlite's isolation_level setter works outside the Thread
|
||||
# that it's supposed to, necessitating setting check_same_thread=False.
|
||||
# for improved stability, we instead invent our own awaitable version
|
||||
# using aiosqlite's async queue directly.
|
||||
|
||||
def set_iso(connection, value):
|
||||
connection.isolation_level = value
|
||||
|
||||
function = partial(set_iso, self._connection._conn, value)
|
||||
future = asyncio.get_event_loop().create_future()
|
||||
|
||||
self._connection._tx.put_nowait((future, function))
|
||||
|
||||
try:
|
||||
return self.await_(future)
|
||||
except Exception as error:
|
||||
self._handle_exception(error)
|
||||
|
||||
def create_function(self, *args, **kw):
|
||||
try:
|
||||
self.await_(self._connection.create_function(*args, **kw))
|
||||
except Exception as error:
|
||||
self._handle_exception(error)
|
||||
|
||||
def cursor(self, server_side=False):
|
||||
if server_side:
|
||||
return AsyncAdapt_aiosqlite_ss_cursor(self)
|
||||
else:
|
||||
return AsyncAdapt_aiosqlite_cursor(self)
|
||||
|
||||
def execute(self, *args, **kw):
|
||||
return self.await_(self._connection.execute(*args, **kw))
|
||||
|
||||
def rollback(self):
|
||||
try:
|
||||
self.await_(self._connection.rollback())
|
||||
except Exception as error:
|
||||
self._handle_exception(error)
|
||||
|
||||
def commit(self):
|
||||
try:
|
||||
self.await_(self._connection.commit())
|
||||
except Exception as error:
|
||||
self._handle_exception(error)
|
||||
|
||||
def close(self):
|
||||
try:
|
||||
self.await_(self._connection.close())
|
||||
except ValueError:
|
||||
# this is undocumented for aiosqlite, that ValueError
|
||||
# was raised if .close() was called more than once, which is
|
||||
# both not customary for DBAPI and is also not a DBAPI.Error
|
||||
# exception. This is now fixed in aiosqlite via my PR
|
||||
# https://github.com/omnilib/aiosqlite/pull/238, so we can be
|
||||
# assured this will not become some other kind of exception,
|
||||
# since it doesn't raise anymore.
|
||||
|
||||
pass
|
||||
except Exception as error:
|
||||
self._handle_exception(error)
|
||||
|
||||
def _handle_exception(self, error):
|
||||
if (
|
||||
isinstance(error, ValueError)
|
||||
and error.args[0] == "no active connection"
|
||||
):
|
||||
raise self.dbapi.sqlite.OperationalError(
|
||||
"no active connection"
|
||||
) from error
|
||||
else:
|
||||
raise error
|
||||
|
||||
|
||||
class AsyncAdaptFallback_aiosqlite_connection(AsyncAdapt_aiosqlite_connection):
|
||||
__slots__ = ()
|
||||
|
||||
await_ = staticmethod(await_fallback)
|
||||
|
||||
|
||||
class AsyncAdapt_aiosqlite_dbapi:
|
||||
def __init__(self, aiosqlite, sqlite):
|
||||
self.aiosqlite = aiosqlite
|
||||
self.sqlite = sqlite
|
||||
self.paramstyle = "qmark"
|
||||
self._init_dbapi_attributes()
|
||||
|
||||
def _init_dbapi_attributes(self):
|
||||
for name in (
|
||||
"DatabaseError",
|
||||
"Error",
|
||||
"IntegrityError",
|
||||
"NotSupportedError",
|
||||
"OperationalError",
|
||||
"ProgrammingError",
|
||||
"sqlite_version",
|
||||
"sqlite_version_info",
|
||||
):
|
||||
setattr(self, name, getattr(self.aiosqlite, name))
|
||||
|
||||
for name in ("PARSE_COLNAMES", "PARSE_DECLTYPES"):
|
||||
setattr(self, name, getattr(self.sqlite, name))
|
||||
|
||||
for name in ("Binary",):
|
||||
setattr(self, name, getattr(self.sqlite, name))
|
||||
|
||||
def connect(self, *arg, **kw):
|
||||
async_fallback = kw.pop("async_fallback", False)
|
||||
|
||||
creator_fn = kw.pop("async_creator_fn", None)
|
||||
if creator_fn:
|
||||
connection = creator_fn(*arg, **kw)
|
||||
else:
|
||||
connection = self.aiosqlite.connect(*arg, **kw)
|
||||
# it's a Thread. you'll thank us later
|
||||
connection.daemon = True
|
||||
|
||||
if util.asbool(async_fallback):
|
||||
return AsyncAdaptFallback_aiosqlite_connection(
|
||||
self,
|
||||
await_fallback(connection),
|
||||
)
|
||||
else:
|
||||
return AsyncAdapt_aiosqlite_connection(
|
||||
self,
|
||||
await_only(connection),
|
||||
)
|
||||
|
||||
|
||||
class SQLiteExecutionContext_aiosqlite(SQLiteExecutionContext):
|
||||
def create_server_side_cursor(self):
|
||||
return self._dbapi_connection.cursor(server_side=True)
|
||||
|
||||
|
||||
class SQLiteDialect_aiosqlite(SQLiteDialect_pysqlite):
|
||||
driver = "aiosqlite"
|
||||
supports_statement_cache = True
|
||||
|
||||
is_async = True
|
||||
|
||||
supports_server_side_cursors = True
|
||||
|
||||
execution_ctx_cls = SQLiteExecutionContext_aiosqlite
|
||||
|
||||
@classmethod
|
||||
def import_dbapi(cls):
|
||||
return AsyncAdapt_aiosqlite_dbapi(
|
||||
__import__("aiosqlite"), __import__("sqlite3")
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_pool_class(cls, url):
|
||||
if cls._is_url_file_db(url):
|
||||
return pool.NullPool
|
||||
else:
|
||||
return pool.StaticPool
|
||||
|
||||
def is_disconnect(self, e, connection, cursor):
|
||||
if isinstance(
|
||||
e, self.dbapi.OperationalError
|
||||
) and "no active connection" in str(e):
|
||||
return True
|
||||
|
||||
return super().is_disconnect(e, connection, cursor)
|
||||
|
||||
def get_driver_connection(self, connection):
|
||||
return connection._connection
|
||||
|
||||
|
||||
dialect = SQLiteDialect_aiosqlite
|
File diff suppressed because it is too large
Load diff
|
@ -0,0 +1,240 @@
|
|||
# dialects/sqlite/dml.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from .._typing import _OnConflictIndexElementsT
|
||||
from .._typing import _OnConflictIndexWhereT
|
||||
from .._typing import _OnConflictSetT
|
||||
from .._typing import _OnConflictWhereT
|
||||
from ... import util
|
||||
from ...sql import coercions
|
||||
from ...sql import roles
|
||||
from ...sql._typing import _DMLTableArgument
|
||||
from ...sql.base import _exclusive_against
|
||||
from ...sql.base import _generative
|
||||
from ...sql.base import ColumnCollection
|
||||
from ...sql.base import ReadOnlyColumnCollection
|
||||
from ...sql.dml import Insert as StandardInsert
|
||||
from ...sql.elements import ClauseElement
|
||||
from ...sql.elements import KeyedColumnElement
|
||||
from ...sql.expression import alias
|
||||
from ...util.typing import Self
|
||||
|
||||
__all__ = ("Insert", "insert")
|
||||
|
||||
|
||||
def insert(table: _DMLTableArgument) -> Insert:
|
||||
"""Construct a sqlite-specific variant :class:`_sqlite.Insert`
|
||||
construct.
|
||||
|
||||
.. container:: inherited_member
|
||||
|
||||
The :func:`sqlalchemy.dialects.sqlite.insert` function creates
|
||||
a :class:`sqlalchemy.dialects.sqlite.Insert`. This class is based
|
||||
on the dialect-agnostic :class:`_sql.Insert` construct which may
|
||||
be constructed using the :func:`_sql.insert` function in
|
||||
SQLAlchemy Core.
|
||||
|
||||
The :class:`_sqlite.Insert` construct includes additional methods
|
||||
:meth:`_sqlite.Insert.on_conflict_do_update`,
|
||||
:meth:`_sqlite.Insert.on_conflict_do_nothing`.
|
||||
|
||||
"""
|
||||
return Insert(table)
|
||||
|
||||
|
||||
class Insert(StandardInsert):
|
||||
"""SQLite-specific implementation of INSERT.
|
||||
|
||||
Adds methods for SQLite-specific syntaxes such as ON CONFLICT.
|
||||
|
||||
The :class:`_sqlite.Insert` object is created using the
|
||||
:func:`sqlalchemy.dialects.sqlite.insert` function.
|
||||
|
||||
.. versionadded:: 1.4
|
||||
|
||||
.. seealso::
|
||||
|
||||
:ref:`sqlite_on_conflict_insert`
|
||||
|
||||
"""
|
||||
|
||||
stringify_dialect = "sqlite"
|
||||
inherit_cache = False
|
||||
|
||||
@util.memoized_property
|
||||
def excluded(
|
||||
self,
|
||||
) -> ReadOnlyColumnCollection[str, KeyedColumnElement[Any]]:
|
||||
"""Provide the ``excluded`` namespace for an ON CONFLICT statement
|
||||
|
||||
SQLite's ON CONFLICT clause allows reference to the row that would
|
||||
be inserted, known as ``excluded``. This attribute provides
|
||||
all columns in this row to be referenceable.
|
||||
|
||||
.. tip:: The :attr:`_sqlite.Insert.excluded` attribute is an instance
|
||||
of :class:`_expression.ColumnCollection`, which provides an
|
||||
interface the same as that of the :attr:`_schema.Table.c`
|
||||
collection described at :ref:`metadata_tables_and_columns`.
|
||||
With this collection, ordinary names are accessible like attributes
|
||||
(e.g. ``stmt.excluded.some_column``), but special names and
|
||||
dictionary method names should be accessed using indexed access,
|
||||
such as ``stmt.excluded["column name"]`` or
|
||||
``stmt.excluded["values"]``. See the docstring for
|
||||
:class:`_expression.ColumnCollection` for further examples.
|
||||
|
||||
"""
|
||||
return alias(self.table, name="excluded").columns
|
||||
|
||||
_on_conflict_exclusive = _exclusive_against(
|
||||
"_post_values_clause",
|
||||
msgs={
|
||||
"_post_values_clause": "This Insert construct already has "
|
||||
"an ON CONFLICT clause established"
|
||||
},
|
||||
)
|
||||
|
||||
@_generative
|
||||
@_on_conflict_exclusive
|
||||
def on_conflict_do_update(
|
||||
self,
|
||||
index_elements: _OnConflictIndexElementsT = None,
|
||||
index_where: _OnConflictIndexWhereT = None,
|
||||
set_: _OnConflictSetT = None,
|
||||
where: _OnConflictWhereT = None,
|
||||
) -> Self:
|
||||
r"""
|
||||
Specifies a DO UPDATE SET action for ON CONFLICT clause.
|
||||
|
||||
:param index_elements:
|
||||
A sequence consisting of string column names, :class:`_schema.Column`
|
||||
objects, or other column expression objects that will be used
|
||||
to infer a target index or unique constraint.
|
||||
|
||||
:param index_where:
|
||||
Additional WHERE criterion that can be used to infer a
|
||||
conditional target index.
|
||||
|
||||
:param set\_:
|
||||
A dictionary or other mapping object
|
||||
where the keys are either names of columns in the target table,
|
||||
or :class:`_schema.Column` objects or other ORM-mapped columns
|
||||
matching that of the target table, and expressions or literals
|
||||
as values, specifying the ``SET`` actions to take.
|
||||
|
||||
.. versionadded:: 1.4 The
|
||||
:paramref:`_sqlite.Insert.on_conflict_do_update.set_`
|
||||
parameter supports :class:`_schema.Column` objects from the target
|
||||
:class:`_schema.Table` as keys.
|
||||
|
||||
.. warning:: This dictionary does **not** take into account
|
||||
Python-specified default UPDATE values or generation functions,
|
||||
e.g. those specified using :paramref:`_schema.Column.onupdate`.
|
||||
These values will not be exercised for an ON CONFLICT style of
|
||||
UPDATE, unless they are manually specified in the
|
||||
:paramref:`.Insert.on_conflict_do_update.set_` dictionary.
|
||||
|
||||
:param where:
|
||||
Optional argument. If present, can be a literal SQL
|
||||
string or an acceptable expression for a ``WHERE`` clause
|
||||
that restricts the rows affected by ``DO UPDATE SET``. Rows
|
||||
not meeting the ``WHERE`` condition will not be updated
|
||||
(effectively a ``DO NOTHING`` for those rows).
|
||||
|
||||
"""
|
||||
|
||||
self._post_values_clause = OnConflictDoUpdate(
|
||||
index_elements, index_where, set_, where
|
||||
)
|
||||
return self
|
||||
|
||||
@_generative
|
||||
@_on_conflict_exclusive
|
||||
def on_conflict_do_nothing(
|
||||
self,
|
||||
index_elements: _OnConflictIndexElementsT = None,
|
||||
index_where: _OnConflictIndexWhereT = None,
|
||||
) -> Self:
|
||||
"""
|
||||
Specifies a DO NOTHING action for ON CONFLICT clause.
|
||||
|
||||
:param index_elements:
|
||||
A sequence consisting of string column names, :class:`_schema.Column`
|
||||
objects, or other column expression objects that will be used
|
||||
to infer a target index or unique constraint.
|
||||
|
||||
:param index_where:
|
||||
Additional WHERE criterion that can be used to infer a
|
||||
conditional target index.
|
||||
|
||||
"""
|
||||
|
||||
self._post_values_clause = OnConflictDoNothing(
|
||||
index_elements, index_where
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
class OnConflictClause(ClauseElement):
|
||||
stringify_dialect = "sqlite"
|
||||
|
||||
constraint_target: None
|
||||
inferred_target_elements: _OnConflictIndexElementsT
|
||||
inferred_target_whereclause: _OnConflictIndexWhereT
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index_elements: _OnConflictIndexElementsT = None,
|
||||
index_where: _OnConflictIndexWhereT = None,
|
||||
):
|
||||
if index_elements is not None:
|
||||
self.constraint_target = None
|
||||
self.inferred_target_elements = index_elements
|
||||
self.inferred_target_whereclause = index_where
|
||||
else:
|
||||
self.constraint_target = self.inferred_target_elements = (
|
||||
self.inferred_target_whereclause
|
||||
) = None
|
||||
|
||||
|
||||
class OnConflictDoNothing(OnConflictClause):
|
||||
__visit_name__ = "on_conflict_do_nothing"
|
||||
|
||||
|
||||
class OnConflictDoUpdate(OnConflictClause):
|
||||
__visit_name__ = "on_conflict_do_update"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
index_elements: _OnConflictIndexElementsT = None,
|
||||
index_where: _OnConflictIndexWhereT = None,
|
||||
set_: _OnConflictSetT = None,
|
||||
where: _OnConflictWhereT = None,
|
||||
):
|
||||
super().__init__(
|
||||
index_elements=index_elements,
|
||||
index_where=index_where,
|
||||
)
|
||||
|
||||
if isinstance(set_, dict):
|
||||
if not set_:
|
||||
raise ValueError("set parameter dictionary must not be empty")
|
||||
elif isinstance(set_, ColumnCollection):
|
||||
set_ = dict(set_)
|
||||
else:
|
||||
raise ValueError(
|
||||
"set parameter must be a non-empty dictionary "
|
||||
"or a ColumnCollection such as the `.c.` collection "
|
||||
"of a Table object"
|
||||
)
|
||||
self.update_values_to_set = [
|
||||
(coercions.expect(roles.DMLColumnRole, key), value)
|
||||
for key, value in set_.items()
|
||||
]
|
||||
self.update_whereclause = where
|
|
@ -0,0 +1,92 @@
|
|||
# dialects/sqlite/json.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
from ... import types as sqltypes
|
||||
|
||||
|
||||
class JSON(sqltypes.JSON):
|
||||
"""SQLite JSON type.
|
||||
|
||||
SQLite supports JSON as of version 3.9 through its JSON1_ extension. Note
|
||||
that JSON1_ is a
|
||||
`loadable extension <https://www.sqlite.org/loadext.html>`_ and as such
|
||||
may not be available, or may require run-time loading.
|
||||
|
||||
:class:`_sqlite.JSON` is used automatically whenever the base
|
||||
:class:`_types.JSON` datatype is used against a SQLite backend.
|
||||
|
||||
.. seealso::
|
||||
|
||||
:class:`_types.JSON` - main documentation for the generic
|
||||
cross-platform JSON datatype.
|
||||
|
||||
The :class:`_sqlite.JSON` type supports persistence of JSON values
|
||||
as well as the core index operations provided by :class:`_types.JSON`
|
||||
datatype, by adapting the operations to render the ``JSON_EXTRACT``
|
||||
function wrapped in the ``JSON_QUOTE`` function at the database level.
|
||||
Extracted values are quoted in order to ensure that the results are
|
||||
always JSON string values.
|
||||
|
||||
|
||||
.. versionadded:: 1.3
|
||||
|
||||
|
||||
.. _JSON1: https://www.sqlite.org/json1.html
|
||||
|
||||
"""
|
||||
|
||||
|
||||
# Note: these objects currently match exactly those of MySQL, however since
|
||||
# these are not generalizable to all JSON implementations, remain separately
|
||||
# implemented for each dialect.
|
||||
class _FormatTypeMixin:
|
||||
def _format_value(self, value):
|
||||
raise NotImplementedError()
|
||||
|
||||
def bind_processor(self, dialect):
|
||||
super_proc = self.string_bind_processor(dialect)
|
||||
|
||||
def process(value):
|
||||
value = self._format_value(value)
|
||||
if super_proc:
|
||||
value = super_proc(value)
|
||||
return value
|
||||
|
||||
return process
|
||||
|
||||
def literal_processor(self, dialect):
|
||||
super_proc = self.string_literal_processor(dialect)
|
||||
|
||||
def process(value):
|
||||
value = self._format_value(value)
|
||||
if super_proc:
|
||||
value = super_proc(value)
|
||||
return value
|
||||
|
||||
return process
|
||||
|
||||
|
||||
class JSONIndexType(_FormatTypeMixin, sqltypes.JSON.JSONIndexType):
|
||||
def _format_value(self, value):
|
||||
if isinstance(value, int):
|
||||
value = "$[%s]" % value
|
||||
else:
|
||||
value = '$."%s"' % value
|
||||
return value
|
||||
|
||||
|
||||
class JSONPathType(_FormatTypeMixin, sqltypes.JSON.JSONPathType):
|
||||
def _format_value(self, value):
|
||||
return "$%s" % (
|
||||
"".join(
|
||||
[
|
||||
"[%s]" % elem if isinstance(elem, int) else '."%s"' % elem
|
||||
for elem in value
|
||||
]
|
||||
)
|
||||
)
|
|
@ -0,0 +1,198 @@
|
|||
# dialects/sqlite/provision.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
import os
|
||||
import re
|
||||
|
||||
from ... import exc
|
||||
from ...engine import url as sa_url
|
||||
from ...testing.provision import create_db
|
||||
from ...testing.provision import drop_db
|
||||
from ...testing.provision import follower_url_from_main
|
||||
from ...testing.provision import generate_driver_url
|
||||
from ...testing.provision import log
|
||||
from ...testing.provision import post_configure_engine
|
||||
from ...testing.provision import run_reap_dbs
|
||||
from ...testing.provision import stop_test_class_outside_fixtures
|
||||
from ...testing.provision import temp_table_keyword_args
|
||||
from ...testing.provision import upsert
|
||||
|
||||
|
||||
# TODO: I can't get this to build dynamically with pytest-xdist procs
|
||||
_drivernames = {
|
||||
"pysqlite",
|
||||
"aiosqlite",
|
||||
"pysqlcipher",
|
||||
"pysqlite_numeric",
|
||||
"pysqlite_dollar",
|
||||
}
|
||||
|
||||
|
||||
def _format_url(url, driver, ident):
|
||||
"""given a sqlite url + desired driver + ident, make a canonical
|
||||
URL out of it
|
||||
|
||||
"""
|
||||
url = sa_url.make_url(url)
|
||||
|
||||
if driver is None:
|
||||
driver = url.get_driver_name()
|
||||
|
||||
filename = url.database
|
||||
|
||||
needs_enc = driver == "pysqlcipher"
|
||||
name_token = None
|
||||
|
||||
if filename and filename != ":memory:":
|
||||
assert "test_schema" not in filename
|
||||
tokens = re.split(r"[_\.]", filename)
|
||||
|
||||
new_filename = f"{driver}"
|
||||
|
||||
for token in tokens:
|
||||
if token in _drivernames:
|
||||
if driver is None:
|
||||
driver = token
|
||||
continue
|
||||
elif token in ("db", "enc"):
|
||||
continue
|
||||
elif name_token is None:
|
||||
name_token = token.strip("_")
|
||||
|
||||
assert name_token, f"sqlite filename has no name token: {url.database}"
|
||||
|
||||
new_filename = f"{name_token}_{driver}"
|
||||
if ident:
|
||||
new_filename += f"_{ident}"
|
||||
new_filename += ".db"
|
||||
if needs_enc:
|
||||
new_filename += ".enc"
|
||||
url = url.set(database=new_filename)
|
||||
|
||||
if needs_enc:
|
||||
url = url.set(password="test")
|
||||
|
||||
url = url.set(drivername="sqlite+%s" % (driver,))
|
||||
|
||||
return url
|
||||
|
||||
|
||||
@generate_driver_url.for_db("sqlite")
|
||||
def generate_driver_url(url, driver, query_str):
|
||||
url = _format_url(url, driver, None)
|
||||
|
||||
try:
|
||||
url.get_dialect()
|
||||
except exc.NoSuchModuleError:
|
||||
return None
|
||||
else:
|
||||
return url
|
||||
|
||||
|
||||
@follower_url_from_main.for_db("sqlite")
|
||||
def _sqlite_follower_url_from_main(url, ident):
|
||||
return _format_url(url, None, ident)
|
||||
|
||||
|
||||
@post_configure_engine.for_db("sqlite")
|
||||
def _sqlite_post_configure_engine(url, engine, follower_ident):
|
||||
from sqlalchemy import event
|
||||
|
||||
if follower_ident:
|
||||
attach_path = f"{follower_ident}_{engine.driver}_test_schema.db"
|
||||
else:
|
||||
attach_path = f"{engine.driver}_test_schema.db"
|
||||
|
||||
@event.listens_for(engine, "connect")
|
||||
def connect(dbapi_connection, connection_record):
|
||||
# use file DBs in all cases, memory acts kind of strangely
|
||||
# as an attached
|
||||
|
||||
# NOTE! this has to be done *per connection*. New sqlite connection,
|
||||
# as we get with say, QueuePool, the attaches are gone.
|
||||
# so schemes to delete those attached files have to be done at the
|
||||
# filesystem level and not rely upon what attachments are in a
|
||||
# particular SQLite connection
|
||||
dbapi_connection.execute(
|
||||
f'ATTACH DATABASE "{attach_path}" AS test_schema'
|
||||
)
|
||||
|
||||
@event.listens_for(engine, "engine_disposed")
|
||||
def dispose(engine):
|
||||
"""most databases should be dropped using
|
||||
stop_test_class_outside_fixtures
|
||||
|
||||
however a few tests like AttachedDBTest might not get triggered on
|
||||
that main hook
|
||||
|
||||
"""
|
||||
|
||||
if os.path.exists(attach_path):
|
||||
os.remove(attach_path)
|
||||
|
||||
filename = engine.url.database
|
||||
|
||||
if filename and filename != ":memory:" and os.path.exists(filename):
|
||||
os.remove(filename)
|
||||
|
||||
|
||||
@create_db.for_db("sqlite")
|
||||
def _sqlite_create_db(cfg, eng, ident):
|
||||
pass
|
||||
|
||||
|
||||
@drop_db.for_db("sqlite")
|
||||
def _sqlite_drop_db(cfg, eng, ident):
|
||||
_drop_dbs_w_ident(eng.url.database, eng.driver, ident)
|
||||
|
||||
|
||||
def _drop_dbs_w_ident(databasename, driver, ident):
|
||||
for path in os.listdir("."):
|
||||
fname, ext = os.path.split(path)
|
||||
if ident in fname and ext in [".db", ".db.enc"]:
|
||||
log.info("deleting SQLite database file: %s", path)
|
||||
os.remove(path)
|
||||
|
||||
|
||||
@stop_test_class_outside_fixtures.for_db("sqlite")
|
||||
def stop_test_class_outside_fixtures(config, db, cls):
|
||||
db.dispose()
|
||||
|
||||
|
||||
@temp_table_keyword_args.for_db("sqlite")
|
||||
def _sqlite_temp_table_keyword_args(cfg, eng):
|
||||
return {"prefixes": ["TEMPORARY"]}
|
||||
|
||||
|
||||
@run_reap_dbs.for_db("sqlite")
|
||||
def _reap_sqlite_dbs(url, idents):
|
||||
log.info("db reaper connecting to %r", url)
|
||||
log.info("identifiers in file: %s", ", ".join(idents))
|
||||
url = sa_url.make_url(url)
|
||||
for ident in idents:
|
||||
for drivername in _drivernames:
|
||||
_drop_dbs_w_ident(url.database, drivername, ident)
|
||||
|
||||
|
||||
@upsert.for_db("sqlite")
|
||||
def _upsert(
|
||||
cfg, table, returning, *, set_lambda=None, sort_by_parameter_order=False
|
||||
):
|
||||
from sqlalchemy.dialects.sqlite import insert
|
||||
|
||||
stmt = insert(table)
|
||||
|
||||
if set_lambda:
|
||||
stmt = stmt.on_conflict_do_update(set_=set_lambda(stmt.excluded))
|
||||
else:
|
||||
stmt = stmt.on_conflict_do_nothing()
|
||||
|
||||
stmt = stmt.returning(
|
||||
*returning, sort_by_parameter_order=sort_by_parameter_order
|
||||
)
|
||||
return stmt
|
|
@ -0,0 +1,155 @@
|
|||
# dialects/sqlite/pysqlcipher.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
"""
|
||||
.. dialect:: sqlite+pysqlcipher
|
||||
:name: pysqlcipher
|
||||
:dbapi: sqlcipher 3 or pysqlcipher
|
||||
:connectstring: sqlite+pysqlcipher://:passphrase@/file_path[?kdf_iter=<iter>]
|
||||
|
||||
Dialect for support of DBAPIs that make use of the
|
||||
`SQLCipher <https://www.zetetic.net/sqlcipher>`_ backend.
|
||||
|
||||
|
||||
Driver
|
||||
------
|
||||
|
||||
Current dialect selection logic is:
|
||||
|
||||
* If the :paramref:`_sa.create_engine.module` parameter supplies a DBAPI module,
|
||||
that module is used.
|
||||
* Otherwise for Python 3, choose https://pypi.org/project/sqlcipher3/
|
||||
* If not available, fall back to https://pypi.org/project/pysqlcipher3/
|
||||
* For Python 2, https://pypi.org/project/pysqlcipher/ is used.
|
||||
|
||||
.. warning:: The ``pysqlcipher3`` and ``pysqlcipher`` DBAPI drivers are no
|
||||
longer maintained; the ``sqlcipher3`` driver as of this writing appears
|
||||
to be current. For future compatibility, any pysqlcipher-compatible DBAPI
|
||||
may be used as follows::
|
||||
|
||||
import sqlcipher_compatible_driver
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
|
||||
e = create_engine(
|
||||
"sqlite+pysqlcipher://:password@/dbname.db",
|
||||
module=sqlcipher_compatible_driver
|
||||
)
|
||||
|
||||
These drivers make use of the SQLCipher engine. This system essentially
|
||||
introduces new PRAGMA commands to SQLite which allows the setting of a
|
||||
passphrase and other encryption parameters, allowing the database file to be
|
||||
encrypted.
|
||||
|
||||
|
||||
Connect Strings
|
||||
---------------
|
||||
|
||||
The format of the connect string is in every way the same as that
|
||||
of the :mod:`~sqlalchemy.dialects.sqlite.pysqlite` driver, except that the
|
||||
"password" field is now accepted, which should contain a passphrase::
|
||||
|
||||
e = create_engine('sqlite+pysqlcipher://:testing@/foo.db')
|
||||
|
||||
For an absolute file path, two leading slashes should be used for the
|
||||
database name::
|
||||
|
||||
e = create_engine('sqlite+pysqlcipher://:testing@//path/to/foo.db')
|
||||
|
||||
A selection of additional encryption-related pragmas supported by SQLCipher
|
||||
as documented at https://www.zetetic.net/sqlcipher/sqlcipher-api/ can be passed
|
||||
in the query string, and will result in that PRAGMA being called for each
|
||||
new connection. Currently, ``cipher``, ``kdf_iter``
|
||||
``cipher_page_size`` and ``cipher_use_hmac`` are supported::
|
||||
|
||||
e = create_engine('sqlite+pysqlcipher://:testing@/foo.db?cipher=aes-256-cfb&kdf_iter=64000')
|
||||
|
||||
.. warning:: Previous versions of sqlalchemy did not take into consideration
|
||||
the encryption-related pragmas passed in the url string, that were silently
|
||||
ignored. This may cause errors when opening files saved by a
|
||||
previous sqlalchemy version if the encryption options do not match.
|
||||
|
||||
|
||||
Pooling Behavior
|
||||
----------------
|
||||
|
||||
The driver makes a change to the default pool behavior of pysqlite
|
||||
as described in :ref:`pysqlite_threading_pooling`. The pysqlcipher driver
|
||||
has been observed to be significantly slower on connection than the
|
||||
pysqlite driver, most likely due to the encryption overhead, so the
|
||||
dialect here defaults to using the :class:`.SingletonThreadPool`
|
||||
implementation,
|
||||
instead of the :class:`.NullPool` pool used by pysqlite. As always, the pool
|
||||
implementation is entirely configurable using the
|
||||
:paramref:`_sa.create_engine.poolclass` parameter; the :class:`.
|
||||
StaticPool` may
|
||||
be more feasible for single-threaded use, or :class:`.NullPool` may be used
|
||||
to prevent unencrypted connections from being held open for long periods of
|
||||
time, at the expense of slower startup time for new connections.
|
||||
|
||||
|
||||
""" # noqa
|
||||
|
||||
from .pysqlite import SQLiteDialect_pysqlite
|
||||
from ... import pool
|
||||
|
||||
|
||||
class SQLiteDialect_pysqlcipher(SQLiteDialect_pysqlite):
|
||||
driver = "pysqlcipher"
|
||||
supports_statement_cache = True
|
||||
|
||||
pragmas = ("kdf_iter", "cipher", "cipher_page_size", "cipher_use_hmac")
|
||||
|
||||
@classmethod
|
||||
def import_dbapi(cls):
|
||||
try:
|
||||
import sqlcipher3 as sqlcipher
|
||||
except ImportError:
|
||||
pass
|
||||
else:
|
||||
return sqlcipher
|
||||
|
||||
from pysqlcipher3 import dbapi2 as sqlcipher
|
||||
|
||||
return sqlcipher
|
||||
|
||||
@classmethod
|
||||
def get_pool_class(cls, url):
|
||||
return pool.SingletonThreadPool
|
||||
|
||||
def on_connect_url(self, url):
|
||||
super_on_connect = super().on_connect_url(url)
|
||||
|
||||
# pull the info we need from the URL early. Even though URL
|
||||
# is immutable, we don't want any in-place changes to the URL
|
||||
# to affect things
|
||||
passphrase = url.password or ""
|
||||
url_query = dict(url.query)
|
||||
|
||||
def on_connect(conn):
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('pragma key="%s"' % passphrase)
|
||||
for prag in self.pragmas:
|
||||
value = url_query.get(prag, None)
|
||||
if value is not None:
|
||||
cursor.execute('pragma %s="%s"' % (prag, value))
|
||||
cursor.close()
|
||||
|
||||
if super_on_connect:
|
||||
super_on_connect(conn)
|
||||
|
||||
return on_connect
|
||||
|
||||
def create_connect_args(self, url):
|
||||
plain_url = url._replace(password=None)
|
||||
plain_url = plain_url.difference_update_query(self.pragmas)
|
||||
return super().create_connect_args(plain_url)
|
||||
|
||||
|
||||
dialect = SQLiteDialect_pysqlcipher
|
|
@ -0,0 +1,753 @@
|
|||
# dialects/sqlite/pysqlite.py
|
||||
# Copyright (C) 2005-2024 the SQLAlchemy authors and contributors
|
||||
# <see AUTHORS file>
|
||||
#
|
||||
# This module is part of SQLAlchemy and is released under
|
||||
# the MIT License: https://www.opensource.org/licenses/mit-license.php
|
||||
# mypy: ignore-errors
|
||||
|
||||
|
||||
r"""
|
||||
.. dialect:: sqlite+pysqlite
|
||||
:name: pysqlite
|
||||
:dbapi: sqlite3
|
||||
:connectstring: sqlite+pysqlite:///file_path
|
||||
:url: https://docs.python.org/library/sqlite3.html
|
||||
|
||||
Note that ``pysqlite`` is the same driver as the ``sqlite3``
|
||||
module included with the Python distribution.
|
||||
|
||||
Driver
|
||||
------
|
||||
|
||||
The ``sqlite3`` Python DBAPI is standard on all modern Python versions;
|
||||
for cPython and Pypy, no additional installation is necessary.
|
||||
|
||||
|
||||
Connect Strings
|
||||
---------------
|
||||
|
||||
The file specification for the SQLite database is taken as the "database"
|
||||
portion of the URL. Note that the format of a SQLAlchemy url is::
|
||||
|
||||
driver://user:pass@host/database
|
||||
|
||||
This means that the actual filename to be used starts with the characters to
|
||||
the **right** of the third slash. So connecting to a relative filepath
|
||||
looks like::
|
||||
|
||||
# relative path
|
||||
e = create_engine('sqlite:///path/to/database.db')
|
||||
|
||||
An absolute path, which is denoted by starting with a slash, means you
|
||||
need **four** slashes::
|
||||
|
||||
# absolute path
|
||||
e = create_engine('sqlite:////path/to/database.db')
|
||||
|
||||
To use a Windows path, regular drive specifications and backslashes can be
|
||||
used. Double backslashes are probably needed::
|
||||
|
||||
# absolute path on Windows
|
||||
e = create_engine('sqlite:///C:\\path\\to\\database.db')
|
||||
|
||||
The sqlite ``:memory:`` identifier is the default if no filepath is
|
||||
present. Specify ``sqlite://`` and nothing else::
|
||||
|
||||
# in-memory database
|
||||
e = create_engine('sqlite://')
|
||||
|
||||
.. _pysqlite_uri_connections:
|
||||
|
||||
URI Connections
|
||||
^^^^^^^^^^^^^^^
|
||||
|
||||
Modern versions of SQLite support an alternative system of connecting using a
|
||||
`driver level URI <https://www.sqlite.org/uri.html>`_, which has the advantage
|
||||
that additional driver-level arguments can be passed including options such as
|
||||
"read only". The Python sqlite3 driver supports this mode under modern Python
|
||||
3 versions. The SQLAlchemy pysqlite driver supports this mode of use by
|
||||
specifying "uri=true" in the URL query string. The SQLite-level "URI" is kept
|
||||
as the "database" portion of the SQLAlchemy url (that is, following a slash)::
|
||||
|
||||
e = create_engine("sqlite:///file:path/to/database?mode=ro&uri=true")
|
||||
|
||||
.. note:: The "uri=true" parameter must appear in the **query string**
|
||||
of the URL. It will not currently work as expected if it is only
|
||||
present in the :paramref:`_sa.create_engine.connect_args`
|
||||
parameter dictionary.
|
||||
|
||||
The logic reconciles the simultaneous presence of SQLAlchemy's query string and
|
||||
SQLite's query string by separating out the parameters that belong to the
|
||||
Python sqlite3 driver vs. those that belong to the SQLite URI. This is
|
||||
achieved through the use of a fixed list of parameters known to be accepted by
|
||||
the Python side of the driver. For example, to include a URL that indicates
|
||||
the Python sqlite3 "timeout" and "check_same_thread" parameters, along with the
|
||||
SQLite "mode" and "nolock" parameters, they can all be passed together on the
|
||||
query string::
|
||||
|
||||
e = create_engine(
|
||||
"sqlite:///file:path/to/database?"
|
||||
"check_same_thread=true&timeout=10&mode=ro&nolock=1&uri=true"
|
||||
)
|
||||
|
||||
Above, the pysqlite / sqlite3 DBAPI would be passed arguments as::
|
||||
|
||||
sqlite3.connect(
|
||||
"file:path/to/database?mode=ro&nolock=1",
|
||||
check_same_thread=True, timeout=10, uri=True
|
||||
)
|
||||
|
||||
Regarding future parameters added to either the Python or native drivers. new
|
||||
parameter names added to the SQLite URI scheme should be automatically
|
||||
accommodated by this scheme. New parameter names added to the Python driver
|
||||
side can be accommodated by specifying them in the
|
||||
:paramref:`_sa.create_engine.connect_args` dictionary,
|
||||
until dialect support is
|
||||
added by SQLAlchemy. For the less likely case that the native SQLite driver
|
||||
adds a new parameter name that overlaps with one of the existing, known Python
|
||||
driver parameters (such as "timeout" perhaps), SQLAlchemy's dialect would
|
||||
require adjustment for the URL scheme to continue to support this.
|
||||
|
||||
As is always the case for all SQLAlchemy dialects, the entire "URL" process
|
||||
can be bypassed in :func:`_sa.create_engine` through the use of the
|
||||
:paramref:`_sa.create_engine.creator`
|
||||
parameter which allows for a custom callable
|
||||
that creates a Python sqlite3 driver level connection directly.
|
||||
|
||||
.. versionadded:: 1.3.9
|
||||
|
||||
.. seealso::
|
||||
|
||||
`Uniform Resource Identifiers <https://www.sqlite.org/uri.html>`_ - in
|
||||
the SQLite documentation
|
||||
|
||||
.. _pysqlite_regexp:
|
||||
|
||||
Regular Expression Support
|
||||
---------------------------
|
||||
|
||||
.. versionadded:: 1.4
|
||||
|
||||
Support for the :meth:`_sql.ColumnOperators.regexp_match` operator is provided
|
||||
using Python's re.search_ function. SQLite itself does not include a working
|
||||
regular expression operator; instead, it includes a non-implemented placeholder
|
||||
operator ``REGEXP`` that calls a user-defined function that must be provided.
|
||||
|
||||
SQLAlchemy's implementation makes use of the pysqlite create_function_ hook
|
||||
as follows::
|
||||
|
||||
|
||||
def regexp(a, b):
|
||||
return re.search(a, b) is not None
|
||||
|
||||
sqlite_connection.create_function(
|
||||
"regexp", 2, regexp,
|
||||
)
|
||||
|
||||
There is currently no support for regular expression flags as a separate
|
||||
argument, as these are not supported by SQLite's REGEXP operator, however these
|
||||
may be included inline within the regular expression string. See `Python regular expressions`_ for
|
||||
details.
|
||||
|
||||
.. seealso::
|
||||
|
||||
`Python regular expressions`_: Documentation for Python's regular expression syntax.
|
||||
|
||||
.. _create_function: https://docs.python.org/3/library/sqlite3.html#sqlite3.Connection.create_function
|
||||
|
||||
.. _re.search: https://docs.python.org/3/library/re.html#re.search
|
||||
|
||||
.. _Python regular expressions: https://docs.python.org/3/library/re.html#re.search
|
||||
|
||||
|
||||
|
||||
Compatibility with sqlite3 "native" date and datetime types
|
||||
-----------------------------------------------------------
|
||||
|
||||
The pysqlite driver includes the sqlite3.PARSE_DECLTYPES and
|
||||
sqlite3.PARSE_COLNAMES options, which have the effect of any column
|
||||
or expression explicitly cast as "date" or "timestamp" will be converted
|
||||
to a Python date or datetime object. The date and datetime types provided
|
||||
with the pysqlite dialect are not currently compatible with these options,
|
||||
since they render the ISO date/datetime including microseconds, which
|
||||
pysqlite's driver does not. Additionally, SQLAlchemy does not at
|
||||
this time automatically render the "cast" syntax required for the
|
||||
freestanding functions "current_timestamp" and "current_date" to return
|
||||
datetime/date types natively. Unfortunately, pysqlite
|
||||
does not provide the standard DBAPI types in ``cursor.description``,
|
||||
leaving SQLAlchemy with no way to detect these types on the fly
|
||||
without expensive per-row type checks.
|
||||
|
||||
Keeping in mind that pysqlite's parsing option is not recommended,
|
||||
nor should be necessary, for use with SQLAlchemy, usage of PARSE_DECLTYPES
|
||||
can be forced if one configures "native_datetime=True" on create_engine()::
|
||||
|
||||
engine = create_engine('sqlite://',
|
||||
connect_args={'detect_types':
|
||||
sqlite3.PARSE_DECLTYPES|sqlite3.PARSE_COLNAMES},
|
||||
native_datetime=True
|
||||
)
|
||||
|
||||
With this flag enabled, the DATE and TIMESTAMP types (but note - not the
|
||||
DATETIME or TIME types...confused yet ?) will not perform any bind parameter
|
||||
or result processing. Execution of "func.current_date()" will return a string.
|
||||
"func.current_timestamp()" is registered as returning a DATETIME type in
|
||||
SQLAlchemy, so this function still receives SQLAlchemy-level result
|
||||
processing.
|
||||
|
||||
.. _pysqlite_threading_pooling:
|
||||
|
||||
Threading/Pooling Behavior
|
||||
---------------------------
|
||||
|
||||
The ``sqlite3`` DBAPI by default prohibits the use of a particular connection
|
||||
in a thread which is not the one in which it was created. As SQLite has
|
||||
matured, it's behavior under multiple threads has improved, and even includes
|
||||
options for memory only databases to be used in multiple threads.
|
||||
|
||||
The thread prohibition is known as "check same thread" and may be controlled
|
||||
using the ``sqlite3`` parameter ``check_same_thread``, which will disable or
|
||||
enable this check. SQLAlchemy's default behavior here is to set
|
||||
``check_same_thread`` to ``False`` automatically whenever a file-based database
|
||||
is in use, to establish compatibility with the default pool class
|
||||
:class:`.QueuePool`.
|
||||
|
||||
The SQLAlchemy ``pysqlite`` DBAPI establishes the connection pool differently
|
||||
based on the kind of SQLite database that's requested:
|
||||
|
||||
* When a ``:memory:`` SQLite database is specified, the dialect by default
|
||||
will use :class:`.SingletonThreadPool`. This pool maintains a single
|
||||
connection per thread, so that all access to the engine within the current
|
||||
thread use the same ``:memory:`` database - other threads would access a
|
||||
different ``:memory:`` database. The ``check_same_thread`` parameter
|
||||
defaults to ``True``.
|
||||
* When a file-based database is specified, the dialect will use
|
||||
:class:`.QueuePool` as the source of connections. at the same time,
|
||||
the ``check_same_thread`` flag is set to False by default unless overridden.
|
||||
|
||||
.. versionchanged:: 2.0
|
||||
|
||||
SQLite file database engines now use :class:`.QueuePool` by default.
|
||||
Previously, :class:`.NullPool` were used. The :class:`.NullPool` class
|
||||
may be used by specifying it via the
|
||||
:paramref:`_sa.create_engine.poolclass` parameter.
|
||||
|
||||
Disabling Connection Pooling for File Databases
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Pooling may be disabled for a file based database by specifying the
|
||||
:class:`.NullPool` implementation for the :func:`_sa.create_engine.poolclass`
|
||||
parameter::
|
||||
|
||||
from sqlalchemy import NullPool
|
||||
engine = create_engine("sqlite:///myfile.db", poolclass=NullPool)
|
||||
|
||||
It's been observed that the :class:`.NullPool` implementation incurs an
|
||||
extremely small performance overhead for repeated checkouts due to the lack of
|
||||
connection re-use implemented by :class:`.QueuePool`. However, it still
|
||||
may be beneficial to use this class if the application is experiencing
|
||||
issues with files being locked.
|
||||
|
||||
Using a Memory Database in Multiple Threads
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
To use a ``:memory:`` database in a multithreaded scenario, the same
|
||||
connection object must be shared among threads, since the database exists
|
||||
only within the scope of that connection. The
|
||||
:class:`.StaticPool` implementation will maintain a single connection
|
||||
globally, and the ``check_same_thread`` flag can be passed to Pysqlite
|
||||
as ``False``::
|
||||
|
||||
from sqlalchemy.pool import StaticPool
|
||||
engine = create_engine('sqlite://',
|
||||
connect_args={'check_same_thread':False},
|
||||
poolclass=StaticPool)
|
||||
|
||||
Note that using a ``:memory:`` database in multiple threads requires a recent
|
||||
version of SQLite.
|
||||
|
||||
Using Temporary Tables with SQLite
|
||||
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
Due to the way SQLite deals with temporary tables, if you wish to use a
|
||||
temporary table in a file-based SQLite database across multiple checkouts
|
||||
from the connection pool, such as when using an ORM :class:`.Session` where
|
||||
the temporary table should continue to remain after :meth:`.Session.commit` or
|
||||
:meth:`.Session.rollback` is called, a pool which maintains a single
|
||||
connection must be used. Use :class:`.SingletonThreadPool` if the scope is
|
||||
only needed within the current thread, or :class:`.StaticPool` is scope is
|
||||
needed within multiple threads for this case::
|
||||
|
||||
# maintain the same connection per thread
|
||||
from sqlalchemy.pool import SingletonThreadPool
|
||||
engine = create_engine('sqlite:///mydb.db',
|
||||
poolclass=SingletonThreadPool)
|
||||
|
||||
|
||||
# maintain the same connection across all threads
|
||||
from sqlalchemy.pool import StaticPool
|
||||
engine = create_engine('sqlite:///mydb.db',
|
||||
poolclass=StaticPool)
|
||||
|
||||
Note that :class:`.SingletonThreadPool` should be configured for the number
|
||||
of threads that are to be used; beyond that number, connections will be
|
||||
closed out in a non deterministic way.
|
||||
|
||||
|
||||
Dealing with Mixed String / Binary Columns
|
||||
------------------------------------------------------
|
||||
|
||||
The SQLite database is weakly typed, and as such it is possible when using
|
||||
binary values, which in Python are represented as ``b'some string'``, that a
|
||||
particular SQLite database can have data values within different rows where
|
||||
some of them will be returned as a ``b''`` value by the Pysqlite driver, and
|
||||
others will be returned as Python strings, e.g. ``''`` values. This situation
|
||||
is not known to occur if the SQLAlchemy :class:`.LargeBinary` datatype is used
|
||||
consistently, however if a particular SQLite database has data that was
|
||||
inserted using the Pysqlite driver directly, or when using the SQLAlchemy
|
||||
:class:`.String` type which was later changed to :class:`.LargeBinary`, the
|
||||
table will not be consistently readable because SQLAlchemy's
|
||||
:class:`.LargeBinary` datatype does not handle strings so it has no way of
|
||||
"encoding" a value that is in string format.
|
||||
|
||||
To deal with a SQLite table that has mixed string / binary data in the
|
||||
same column, use a custom type that will check each row individually::
|
||||
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy import TypeDecorator
|
||||
|
||||
class MixedBinary(TypeDecorator):
|
||||
impl = String
|
||||
cache_ok = True
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
if isinstance(value, str):
|
||||
value = bytes(value, 'utf-8')
|
||||
elif value is not None:
|
||||
value = bytes(value)
|
||||
|
||||
return value
|
||||
|
||||
Then use the above ``MixedBinary`` datatype in the place where
|
||||
:class:`.LargeBinary` would normally be used.
|
||||
|
||||
.. _pysqlite_serializable:
|
||||
|
||||
Serializable isolation / Savepoints / Transactional DDL
|
||||
-------------------------------------------------------
|
||||
|
||||
In the section :ref:`sqlite_concurrency`, we refer to the pysqlite
|
||||
driver's assortment of issues that prevent several features of SQLite
|
||||
from working correctly. The pysqlite DBAPI driver has several
|
||||
long-standing bugs which impact the correctness of its transactional
|
||||
behavior. In its default mode of operation, SQLite features such as
|
||||
SERIALIZABLE isolation, transactional DDL, and SAVEPOINT support are
|
||||
non-functional, and in order to use these features, workarounds must
|
||||
be taken.
|
||||
|
||||
The issue is essentially that the driver attempts to second-guess the user's
|
||||
intent, failing to start transactions and sometimes ending them prematurely, in
|
||||
an effort to minimize the SQLite databases's file locking behavior, even
|
||||
though SQLite itself uses "shared" locks for read-only activities.
|
||||
|
||||
SQLAlchemy chooses to not alter this behavior by default, as it is the
|
||||
long-expected behavior of the pysqlite driver; if and when the pysqlite
|
||||
driver attempts to repair these issues, that will be more of a driver towards
|
||||
defaults for SQLAlchemy.
|
||||
|
||||
The good news is that with a few events, we can implement transactional
|
||||
support fully, by disabling pysqlite's feature entirely and emitting BEGIN
|
||||
ourselves. This is achieved using two event listeners::
|
||||
|
||||
from sqlalchemy import create_engine, event
|
||||
|
||||
engine = create_engine("sqlite:///myfile.db")
|
||||
|
||||
@event.listens_for(engine, "connect")
|
||||
def do_connect(dbapi_connection, connection_record):
|
||||
# disable pysqlite's emitting of the BEGIN statement entirely.
|
||||
# also stops it from emitting COMMIT before any DDL.
|
||||
dbapi_connection.isolation_level = None
|
||||
|
||||
@event.listens_for(engine, "begin")
|
||||
def do_begin(conn):
|
||||
# emit our own BEGIN
|
||||
conn.exec_driver_sql("BEGIN")
|
||||
|
||||
.. warning:: When using the above recipe, it is advised to not use the
|
||||
:paramref:`.Connection.execution_options.isolation_level` setting on
|
||||
:class:`_engine.Connection` and :func:`_sa.create_engine`
|
||||
with the SQLite driver,
|
||||
as this function necessarily will also alter the ".isolation_level" setting.
|
||||
|
||||
|
||||
Above, we intercept a new pysqlite connection and disable any transactional
|
||||
integration. Then, at the point at which SQLAlchemy knows that transaction
|
||||
scope is to begin, we emit ``"BEGIN"`` ourselves.
|
||||
|
||||
When we take control of ``"BEGIN"``, we can also control directly SQLite's
|
||||
locking modes, introduced at
|
||||
`BEGIN TRANSACTION <https://sqlite.org/lang_transaction.html>`_,
|
||||
by adding the desired locking mode to our ``"BEGIN"``::
|
||||
|
||||
@event.listens_for(engine, "begin")
|
||||
def do_begin(conn):
|
||||
conn.exec_driver_sql("BEGIN EXCLUSIVE")
|
||||
|
||||
.. seealso::
|
||||
|
||||
`BEGIN TRANSACTION <https://sqlite.org/lang_transaction.html>`_ -
|
||||
on the SQLite site
|
||||
|
||||
`sqlite3 SELECT does not BEGIN a transaction <https://bugs.python.org/issue9924>`_ -
|
||||
on the Python bug tracker
|
||||
|
||||
`sqlite3 module breaks transactions and potentially corrupts data <https://bugs.python.org/issue10740>`_ -
|
||||
on the Python bug tracker
|
||||
|
||||
.. _pysqlite_udfs:
|
||||
|
||||
User-Defined Functions
|
||||
----------------------
|
||||
|
||||
pysqlite supports a `create_function() <https://docs.python.org/3/library/sqlite3.html#sqlite3.Connection.create_function>`_
|
||||
method that allows us to create our own user-defined functions (UDFs) in Python and use them directly in SQLite queries.
|
||||
These functions are registered with a specific DBAPI Connection.
|
||||
|
||||
SQLAlchemy uses connection pooling with file-based SQLite databases, so we need to ensure that the UDF is attached to the
|
||||
connection when it is created. That is accomplished with an event listener::
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy import text
|
||||
|
||||
|
||||
def udf():
|
||||
return "udf-ok"
|
||||
|
||||
|
||||
engine = create_engine("sqlite:///./db_file")
|
||||
|
||||
|
||||
@event.listens_for(engine, "connect")
|
||||
def connect(conn, rec):
|
||||
conn.create_function("udf", 0, udf)
|
||||
|
||||
|
||||
for i in range(5):
|
||||
with engine.connect() as conn:
|
||||
print(conn.scalar(text("SELECT UDF()")))
|
||||
|
||||
|
||||
""" # noqa
|
||||
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
|
||||
from .base import DATE
|
||||
from .base import DATETIME
|
||||
from .base import SQLiteDialect
|
||||
from ... import exc
|
||||
from ... import pool
|
||||
from ... import types as sqltypes
|
||||
from ... import util
|
||||
|
||||
|
||||
class _SQLite_pysqliteTimeStamp(DATETIME):
|
||||
def bind_processor(self, dialect):
|
||||
if dialect.native_datetime:
|
||||
return None
|
||||
else:
|
||||
return DATETIME.bind_processor(self, dialect)
|
||||
|
||||
def result_processor(self, dialect, coltype):
|
||||
if dialect.native_datetime:
|
||||
return None
|
||||
else:
|
||||
return DATETIME.result_processor(self, dialect, coltype)
|
||||
|
||||
|
||||
class _SQLite_pysqliteDate(DATE):
|
||||
def bind_processor(self, dialect):
|
||||
if dialect.native_datetime:
|
||||
return None
|
||||
else:
|
||||
return DATE.bind_processor(self, dialect)
|
||||
|
||||
def result_processor(self, dialect, coltype):
|
||||
if dialect.native_datetime:
|
||||
return None
|
||||
else:
|
||||
return DATE.result_processor(self, dialect, coltype)
|
||||
|
||||
|
||||
class SQLiteDialect_pysqlite(SQLiteDialect):
|
||||
default_paramstyle = "qmark"
|
||||
supports_statement_cache = True
|
||||
returns_native_bytes = True
|
||||
|
||||
colspecs = util.update_copy(
|
||||
SQLiteDialect.colspecs,
|
||||
{
|
||||
sqltypes.Date: _SQLite_pysqliteDate,
|
||||
sqltypes.TIMESTAMP: _SQLite_pysqliteTimeStamp,
|
||||
},
|
||||
)
|
||||
|
||||
description_encoding = None
|
||||
|
||||
driver = "pysqlite"
|
||||
|
||||
@classmethod
|
||||
def import_dbapi(cls):
|
||||
from sqlite3 import dbapi2 as sqlite
|
||||
|
||||
return sqlite
|
||||
|
||||
@classmethod
|
||||
def _is_url_file_db(cls, url):
|
||||
if (url.database and url.database != ":memory:") and (
|
||||
url.query.get("mode", None) != "memory"
|
||||
):
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def get_pool_class(cls, url):
|
||||
if cls._is_url_file_db(url):
|
||||
return pool.QueuePool
|
||||
else:
|
||||
return pool.SingletonThreadPool
|
||||
|
||||
def _get_server_version_info(self, connection):
|
||||
return self.dbapi.sqlite_version_info
|
||||
|
||||
_isolation_lookup = SQLiteDialect._isolation_lookup.union(
|
||||
{
|
||||
"AUTOCOMMIT": None,
|
||||
}
|
||||
)
|
||||
|
||||
def set_isolation_level(self, dbapi_connection, level):
|
||||
if level == "AUTOCOMMIT":
|
||||
dbapi_connection.isolation_level = None
|
||||
else:
|
||||
dbapi_connection.isolation_level = ""
|
||||
return super().set_isolation_level(dbapi_connection, level)
|
||||
|
||||
def on_connect(self):
|
||||
def regexp(a, b):
|
||||
if b is None:
|
||||
return None
|
||||
return re.search(a, b) is not None
|
||||
|
||||
if util.py38 and self._get_server_version_info(None) >= (3, 9):
|
||||
# sqlite must be greater than 3.8.3 for deterministic=True
|
||||
# https://docs.python.org/3/library/sqlite3.html#sqlite3.Connection.create_function
|
||||
# the check is more conservative since there were still issues
|
||||
# with following 3.8 sqlite versions
|
||||
create_func_kw = {"deterministic": True}
|
||||
else:
|
||||
create_func_kw = {}
|
||||
|
||||
def set_regexp(dbapi_connection):
|
||||
dbapi_connection.create_function(
|
||||
"regexp", 2, regexp, **create_func_kw
|
||||
)
|
||||
|
||||
def floor_func(dbapi_connection):
|
||||
# NOTE: floor is optionally present in sqlite 3.35+ , however
|
||||
# as it is normally non-present we deliver floor() unconditionally
|
||||
# for now.
|
||||
# https://www.sqlite.org/lang_mathfunc.html
|
||||
dbapi_connection.create_function(
|
||||
"floor", 1, math.floor, **create_func_kw
|
||||
)
|
||||
|
||||
fns = [set_regexp, floor_func]
|
||||
|
||||
def connect(conn):
|
||||
for fn in fns:
|
||||
fn(conn)
|
||||
|
||||
return connect
|
||||
|
||||
def create_connect_args(self, url):
|
||||
if url.username or url.password or url.host or url.port:
|
||||
raise exc.ArgumentError(
|
||||
"Invalid SQLite URL: %s\n"
|
||||
"Valid SQLite URL forms are:\n"
|
||||
" sqlite:///:memory: (or, sqlite://)\n"
|
||||
" sqlite:///relative/path/to/file.db\n"
|
||||
" sqlite:////absolute/path/to/file.db" % (url,)
|
||||
)
|
||||
|
||||
# theoretically, this list can be augmented, at least as far as
|
||||
# parameter names accepted by sqlite3/pysqlite, using
|
||||
# inspect.getfullargspec(). for the moment this seems like overkill
|
||||
# as these parameters don't change very often, and as always,
|
||||
# parameters passed to connect_args will always go to the
|
||||
# sqlite3/pysqlite driver.
|
||||
pysqlite_args = [
|
||||
("uri", bool),
|
||||
("timeout", float),
|
||||
("isolation_level", str),
|
||||
("detect_types", int),
|
||||
("check_same_thread", bool),
|
||||
("cached_statements", int),
|
||||
]
|
||||
opts = url.query
|
||||
pysqlite_opts = {}
|
||||
for key, type_ in pysqlite_args:
|
||||
util.coerce_kw_type(opts, key, type_, dest=pysqlite_opts)
|
||||
|
||||
if pysqlite_opts.get("uri", False):
|
||||
uri_opts = dict(opts)
|
||||
# here, we are actually separating the parameters that go to
|
||||
# sqlite3/pysqlite vs. those that go the SQLite URI. What if
|
||||
# two names conflict? again, this seems to be not the case right
|
||||
# now, and in the case that new names are added to
|
||||
# either side which overlap, again the sqlite3/pysqlite parameters
|
||||
# can be passed through connect_args instead of in the URL.
|
||||
# If SQLite native URIs add a parameter like "timeout" that
|
||||
# we already have listed here for the python driver, then we need
|
||||
# to adjust for that here.
|
||||
for key, type_ in pysqlite_args:
|
||||
uri_opts.pop(key, None)
|
||||
filename = url.database
|
||||
if uri_opts:
|
||||
# sorting of keys is for unit test support
|
||||
filename += "?" + (
|
||||
"&".join(
|
||||
"%s=%s" % (key, uri_opts[key])
|
||||
for key in sorted(uri_opts)
|
||||
)
|
||||
)
|
||||
else:
|
||||
filename = url.database or ":memory:"
|
||||
if filename != ":memory:":
|
||||
filename = os.path.abspath(filename)
|
||||
|
||||
pysqlite_opts.setdefault(
|
||||
"check_same_thread", not self._is_url_file_db(url)
|
||||
)
|
||||
|
||||
return ([filename], pysqlite_opts)
|
||||
|
||||
def is_disconnect(self, e, connection, cursor):
|
||||
return isinstance(
|
||||
e, self.dbapi.ProgrammingError
|
||||
) and "Cannot operate on a closed database." in str(e)
|
||||
|
||||
|
||||
dialect = SQLiteDialect_pysqlite
|
||||
|
||||
|
||||
class _SQLiteDialect_pysqlite_numeric(SQLiteDialect_pysqlite):
|
||||
"""numeric dialect for testing only
|
||||
|
||||
internal use only. This dialect is **NOT** supported by SQLAlchemy
|
||||
and may change at any time.
|
||||
|
||||
"""
|
||||
|
||||
supports_statement_cache = True
|
||||
default_paramstyle = "numeric"
|
||||
driver = "pysqlite_numeric"
|
||||
|
||||
_first_bind = ":1"
|
||||
_not_in_statement_regexp = None
|
||||
|
||||
def __init__(self, *arg, **kw):
|
||||
kw.setdefault("paramstyle", "numeric")
|
||||
super().__init__(*arg, **kw)
|
||||
|
||||
def create_connect_args(self, url):
|
||||
arg, opts = super().create_connect_args(url)
|
||||
opts["factory"] = self._fix_sqlite_issue_99953()
|
||||
return arg, opts
|
||||
|
||||
def _fix_sqlite_issue_99953(self):
|
||||
import sqlite3
|
||||
|
||||
first_bind = self._first_bind
|
||||
if self._not_in_statement_regexp:
|
||||
nis = self._not_in_statement_regexp
|
||||
|
||||
def _test_sql(sql):
|
||||
m = nis.search(sql)
|
||||
assert not m, f"Found {nis.pattern!r} in {sql!r}"
|
||||
|
||||
else:
|
||||
|
||||
def _test_sql(sql):
|
||||
pass
|
||||
|
||||
def _numeric_param_as_dict(parameters):
|
||||
if parameters:
|
||||
assert isinstance(parameters, tuple)
|
||||
return {
|
||||
str(idx): value for idx, value in enumerate(parameters, 1)
|
||||
}
|
||||
else:
|
||||
return ()
|
||||
|
||||
class SQLiteFix99953Cursor(sqlite3.Cursor):
|
||||
def execute(self, sql, parameters=()):
|
||||
_test_sql(sql)
|
||||
if first_bind in sql:
|
||||
parameters = _numeric_param_as_dict(parameters)
|
||||
return super().execute(sql, parameters)
|
||||
|
||||
def executemany(self, sql, parameters):
|
||||
_test_sql(sql)
|
||||
if first_bind in sql:
|
||||
parameters = [
|
||||
_numeric_param_as_dict(p) for p in parameters
|
||||
]
|
||||
return super().executemany(sql, parameters)
|
||||
|
||||
class SQLiteFix99953Connection(sqlite3.Connection):
|
||||
def cursor(self, factory=None):
|
||||
if factory is None:
|
||||
factory = SQLiteFix99953Cursor
|
||||
return super().cursor(factory=factory)
|
||||
|
||||
def execute(self, sql, parameters=()):
|
||||
_test_sql(sql)
|
||||
if first_bind in sql:
|
||||
parameters = _numeric_param_as_dict(parameters)
|
||||
return super().execute(sql, parameters)
|
||||
|
||||
def executemany(self, sql, parameters):
|
||||
_test_sql(sql)
|
||||
if first_bind in sql:
|
||||
parameters = [
|
||||
_numeric_param_as_dict(p) for p in parameters
|
||||
]
|
||||
return super().executemany(sql, parameters)
|
||||
|
||||
return SQLiteFix99953Connection
|
||||
|
||||
|
||||
class _SQLiteDialect_pysqlite_dollar(_SQLiteDialect_pysqlite_numeric):
|
||||
"""numeric dialect that uses $ for testing only
|
||||
|
||||
internal use only. This dialect is **NOT** supported by SQLAlchemy
|
||||
and may change at any time.
|
||||
|
||||
"""
|
||||
|
||||
supports_statement_cache = True
|
||||
default_paramstyle = "numeric_dollar"
|
||||
driver = "pysqlite_dollar"
|
||||
|
||||
_first_bind = "$1"
|
||||
_not_in_statement_regexp = re.compile(r"[^\d]:\d+")
|
||||
|
||||
def __init__(self, *arg, **kw):
|
||||
kw.setdefault("paramstyle", "numeric_dollar")
|
||||
super().__init__(*arg, **kw)
|
|
@ -0,0 +1,145 @@
|
|||
Rules for Migrating TypeEngine classes to 0.6
|
||||
---------------------------------------------
|
||||
|
||||
1. the TypeEngine classes are used for:
|
||||
|
||||
a. Specifying behavior which needs to occur for bind parameters
|
||||
or result row columns.
|
||||
|
||||
b. Specifying types that are entirely specific to the database
|
||||
in use and have no analogue in the sqlalchemy.types package.
|
||||
|
||||
c. Specifying types where there is an analogue in sqlalchemy.types,
|
||||
but the database in use takes vendor-specific flags for those
|
||||
types.
|
||||
|
||||
d. If a TypeEngine class doesn't provide any of this, it should be
|
||||
*removed* from the dialect.
|
||||
|
||||
2. the TypeEngine classes are *no longer* used for generating DDL. Dialects
|
||||
now have a TypeCompiler subclass which uses the same visit_XXX model as
|
||||
other compilers.
|
||||
|
||||
3. the "ischema_names" and "colspecs" dictionaries are now required members on
|
||||
the Dialect class.
|
||||
|
||||
4. The names of types within dialects are now important. If a dialect-specific type
|
||||
is a subclass of an existing generic type and is only provided for bind/result behavior,
|
||||
the current mixed case naming can remain, i.e. _PGNumeric for Numeric - in this case,
|
||||
end users would never need to use _PGNumeric directly. However, if a dialect-specific
|
||||
type is specifying a type *or* arguments that are not present generically, it should
|
||||
match the real name of the type on that backend, in uppercase. E.g. postgresql.INET,
|
||||
mysql.ENUM, postgresql.ARRAY.
|
||||
|
||||
Or follow this handy flowchart:
|
||||
|
||||
is the type meant to provide bind/result is the type the same name as an
|
||||
behavior to a generic type (i.e. MixedCase) ---- no ---> UPPERCASE type in types.py ?
|
||||
type in types.py ? | |
|
||||
| no yes
|
||||
yes | |
|
||||
| | does your type need special
|
||||
| +<--- yes --- behavior or arguments ?
|
||||
| | |
|
||||
| | no
|
||||
name the type using | |
|
||||
_MixedCase, i.e. v V
|
||||
_OracleBoolean. it name the type don't make a
|
||||
stays private to the dialect identically as that type, make sure the dialect's
|
||||
and is invoked *only* via within the DB, base.py imports the types.py
|
||||
the colspecs dict. using UPPERCASE UPPERCASE name into its namespace
|
||||
| (i.e. BIT, NCHAR, INTERVAL).
|
||||
| Users can import it.
|
||||
| |
|
||||
v v
|
||||
subclass the closest is the name of this type
|
||||
MixedCase type types.py, identical to an UPPERCASE
|
||||
i.e. <--- no ------- name in types.py ?
|
||||
class _DateTime(types.DateTime),
|
||||
class DATETIME2(types.DateTime), |
|
||||
class BIT(types.TypeEngine). yes
|
||||
|
|
||||
v
|
||||
the type should
|
||||
subclass the
|
||||
UPPERCASE
|
||||
type in types.py
|
||||
(i.e. class BLOB(types.BLOB))
|
||||
|
||||
|
||||
Example 1. pysqlite needs bind/result processing for the DateTime type in types.py,
|
||||
which applies to all DateTimes and subclasses. It's named _SLDateTime and
|
||||
subclasses types.DateTime.
|
||||
|
||||
Example 2. MS-SQL has a TIME type which takes a non-standard "precision" argument
|
||||
that is rendered within DDL. So it's named TIME in the MS-SQL dialect's base.py,
|
||||
and subclasses types.TIME. Users can then say mssql.TIME(precision=10).
|
||||
|
||||
Example 3. MS-SQL dialects also need special bind/result processing for date
|
||||
But its DATE type doesn't render DDL differently than that of a plain
|
||||
DATE, i.e. it takes no special arguments. Therefore we are just adding behavior
|
||||
to types.Date, so it's named _MSDate in the MS-SQL dialect's base.py, and subclasses
|
||||
types.Date.
|
||||
|
||||
Example 4. MySQL has a SET type, there's no analogue for this in types.py. So
|
||||
MySQL names it SET in the dialect's base.py, and it subclasses types.String, since
|
||||
it ultimately deals with strings.
|
||||
|
||||
Example 5. PostgreSQL has a DATETIME type. The DBAPIs handle dates correctly,
|
||||
and no special arguments are used in PG's DDL beyond what types.py provides.
|
||||
PostgreSQL dialect therefore imports types.DATETIME into its base.py.
|
||||
|
||||
Ideally one should be able to specify a schema using names imported completely from a
|
||||
dialect, all matching the real name on that backend:
|
||||
|
||||
from sqlalchemy.dialects.postgresql import base as pg
|
||||
|
||||
t = Table('mytable', metadata,
|
||||
Column('id', pg.INTEGER, primary_key=True),
|
||||
Column('name', pg.VARCHAR(300)),
|
||||
Column('inetaddr', pg.INET)
|
||||
)
|
||||
|
||||
where above, the INTEGER and VARCHAR types are ultimately from sqlalchemy.types,
|
||||
but the PG dialect makes them available in its own namespace.
|
||||
|
||||
5. "colspecs" now is a dictionary of generic or uppercased types from sqlalchemy.types
|
||||
linked to types specified in the dialect. Again, if a type in the dialect does not
|
||||
specify any special behavior for bind_processor() or result_processor() and does not
|
||||
indicate a special type only available in this database, it must be *removed* from the
|
||||
module and from this dictionary.
|
||||
|
||||
6. "ischema_names" indicates string descriptions of types as returned from the database
|
||||
linked to TypeEngine classes.
|
||||
|
||||
a. The string name should be matched to the most specific type possible within
|
||||
sqlalchemy.types, unless there is no matching type within sqlalchemy.types in which
|
||||
case it points to a dialect type. *It doesn't matter* if the dialect has its
|
||||
own subclass of that type with special bind/result behavior - reflect to the types.py
|
||||
UPPERCASE type as much as possible. With very few exceptions, all types
|
||||
should reflect to an UPPERCASE type.
|
||||
|
||||
b. If the dialect contains a matching dialect-specific type that takes extra arguments
|
||||
which the generic one does not, then point to the dialect-specific type. E.g.
|
||||
mssql.VARCHAR takes a "collation" parameter which should be preserved.
|
||||
|
||||
5. DDL, or what was formerly issued by "get_col_spec()", is now handled exclusively by
|
||||
a subclass of compiler.GenericTypeCompiler.
|
||||
|
||||
a. your TypeCompiler class will receive generic and uppercase types from
|
||||
sqlalchemy.types. Do not assume the presence of dialect-specific attributes on
|
||||
these types.
|
||||
|
||||
b. the visit_UPPERCASE methods on GenericTypeCompiler should *not* be overridden with
|
||||
methods that produce a different DDL name. Uppercase types don't do any kind of
|
||||
"guessing" - if visit_TIMESTAMP is called, the DDL should render as TIMESTAMP in
|
||||
all cases, regardless of whether or not that type is legal on the backend database.
|
||||
|
||||
c. the visit_UPPERCASE methods *should* be overridden with methods that add additional
|
||||
arguments and flags to those types.
|
||||
|
||||
d. the visit_lowercase methods are overridden to provide an interpretation of a generic
|
||||
type. E.g. visit_large_binary() might be overridden to say "return self.visit_BIT(type_)".
|
||||
|
||||
e. visit_lowercase methods should *never* render strings directly - it should always
|
||||
be via calling a visit_UPPERCASE() method.
|
Loading…
Add table
Add a link
Reference in a new issue