Cleaned up the directories
This commit is contained in:
parent
f708506d68
commit
a683fcffea
1340 changed files with 554582 additions and 6840 deletions
|
@ -0,0 +1,42 @@
|
|||
from .database import ( # noqa
|
||||
create_database,
|
||||
database_exists,
|
||||
drop_database,
|
||||
escape_like,
|
||||
has_index,
|
||||
has_unique_index,
|
||||
is_auto_assigned_date_column,
|
||||
json_sql,
|
||||
jsonb_sql
|
||||
)
|
||||
from .foreign_keys import ( # noqa
|
||||
dependent_objects,
|
||||
get_fk_constraint_for_columns,
|
||||
get_referencing_foreign_keys,
|
||||
group_foreign_keys,
|
||||
merge_references,
|
||||
non_indexed_foreign_keys
|
||||
)
|
||||
from .mock import create_mock_engine, mock_engine # noqa
|
||||
from .orm import ( # noqa
|
||||
cast_if,
|
||||
get_bind,
|
||||
get_class_by_table,
|
||||
get_column_key,
|
||||
get_columns,
|
||||
get_declarative_base,
|
||||
get_hybrid_properties,
|
||||
get_mapper,
|
||||
get_primary_keys,
|
||||
get_tables,
|
||||
get_type,
|
||||
getdotattr,
|
||||
has_changes,
|
||||
identity,
|
||||
is_loaded,
|
||||
naturally_equivalent,
|
||||
quote,
|
||||
table_name
|
||||
)
|
||||
from .render import render_expression, render_statement # noqa
|
||||
from .sort_query import make_order_by_deterministic # noqa
|
|
@ -0,0 +1,659 @@
|
|||
import itertools
|
||||
import os
|
||||
from collections.abc import Mapping, Sequence
|
||||
from copy import copy
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.engine.url import make_url
|
||||
from sqlalchemy.exc import OperationalError, ProgrammingError
|
||||
|
||||
from ..utils import starts_with
|
||||
from .orm import quote
|
||||
|
||||
|
||||
def escape_like(string, escape_char='*'):
|
||||
"""
|
||||
Escape the string paremeter used in SQL LIKE expressions.
|
||||
|
||||
::
|
||||
|
||||
from sqlalchemy_utils import escape_like
|
||||
|
||||
|
||||
query = session.query(User).filter(
|
||||
User.name.ilike(escape_like('John'))
|
||||
)
|
||||
|
||||
|
||||
:param string: a string to escape
|
||||
:param escape_char: escape character
|
||||
"""
|
||||
return (
|
||||
string
|
||||
.replace(escape_char, escape_char * 2)
|
||||
.replace('%', escape_char + '%')
|
||||
.replace('_', escape_char + '_')
|
||||
)
|
||||
|
||||
|
||||
def json_sql(value, scalars_to_json=True):
|
||||
"""
|
||||
Convert python data structures to PostgreSQL specific SQLAlchemy JSON
|
||||
constructs. This function is extremly useful if you need to build
|
||||
PostgreSQL JSON on python side.
|
||||
|
||||
.. note::
|
||||
|
||||
This function needs PostgreSQL >= 9.4
|
||||
|
||||
Scalars are converted to to_json SQLAlchemy function objects
|
||||
|
||||
::
|
||||
|
||||
json_sql(1) # Equals SQL: to_json(1)
|
||||
|
||||
json_sql('a') # to_json('a')
|
||||
|
||||
|
||||
Mappings are converted to json_build_object constructs
|
||||
|
||||
::
|
||||
|
||||
json_sql({'a': 'c', '2': 5}) # json_build_object('a', 'c', '2', 5)
|
||||
|
||||
|
||||
Sequences (other than strings) are converted to json_build_array constructs
|
||||
|
||||
::
|
||||
|
||||
json_sql([1, 2, 3]) # json_build_array(1, 2, 3)
|
||||
|
||||
|
||||
You can also nest these data structures
|
||||
|
||||
::
|
||||
|
||||
json_sql({'a': [1, 2, 3]})
|
||||
# json_build_object('a', json_build_array[1, 2, 3])
|
||||
|
||||
|
||||
:param value:
|
||||
value to be converted to SQLAlchemy PostgreSQL function constructs
|
||||
"""
|
||||
scalar_convert = sa.text
|
||||
if scalars_to_json:
|
||||
def scalar_convert(a):
|
||||
return sa.func.to_json(sa.text(a))
|
||||
|
||||
if isinstance(value, Mapping):
|
||||
return sa.func.json_build_object(
|
||||
*(
|
||||
json_sql(v, scalars_to_json=False)
|
||||
for v in itertools.chain(*value.items())
|
||||
)
|
||||
)
|
||||
elif isinstance(value, str):
|
||||
return scalar_convert(f"'{value}'")
|
||||
elif isinstance(value, Sequence):
|
||||
return sa.func.json_build_array(
|
||||
*(
|
||||
json_sql(v, scalars_to_json=False)
|
||||
for v in value
|
||||
)
|
||||
)
|
||||
elif isinstance(value, (int, float)):
|
||||
return scalar_convert(str(value))
|
||||
return value
|
||||
|
||||
|
||||
def jsonb_sql(value, scalars_to_jsonb=True):
|
||||
"""
|
||||
Convert python data structures to PostgreSQL specific SQLAlchemy JSONB
|
||||
constructs. This function is extremly useful if you need to build
|
||||
PostgreSQL JSONB on python side.
|
||||
|
||||
.. note::
|
||||
|
||||
This function needs PostgreSQL >= 9.4
|
||||
|
||||
Scalars are converted to to_jsonb SQLAlchemy function objects
|
||||
|
||||
::
|
||||
|
||||
jsonb_sql(1) # Equals SQL: to_jsonb(1)
|
||||
|
||||
jsonb_sql('a') # to_jsonb('a')
|
||||
|
||||
|
||||
Mappings are converted to jsonb_build_object constructs
|
||||
|
||||
::
|
||||
|
||||
jsonb_sql({'a': 'c', '2': 5}) # jsonb_build_object('a', 'c', '2', 5)
|
||||
|
||||
|
||||
Sequences (other than strings) converted to jsonb_build_array constructs
|
||||
|
||||
::
|
||||
|
||||
jsonb_sql([1, 2, 3]) # jsonb_build_array(1, 2, 3)
|
||||
|
||||
|
||||
You can also nest these data structures
|
||||
|
||||
::
|
||||
|
||||
jsonb_sql({'a': [1, 2, 3]})
|
||||
# jsonb_build_object('a', jsonb_build_array[1, 2, 3])
|
||||
|
||||
|
||||
:param value:
|
||||
value to be converted to SQLAlchemy PostgreSQL function constructs
|
||||
:boolean jsonbb:
|
||||
Flag to alternatively convert the return with a to_jsonb construct
|
||||
"""
|
||||
scalar_convert = sa.text
|
||||
if scalars_to_jsonb:
|
||||
def scalar_convert(a):
|
||||
return sa.func.to_jsonb(sa.text(a))
|
||||
|
||||
if isinstance(value, Mapping):
|
||||
return sa.func.jsonb_build_object(
|
||||
*(
|
||||
jsonb_sql(v, scalars_to_jsonb=False)
|
||||
for v in itertools.chain(*value.items())
|
||||
)
|
||||
)
|
||||
elif isinstance(value, str):
|
||||
return scalar_convert(f"'{value}'")
|
||||
elif isinstance(value, Sequence):
|
||||
return sa.func.jsonb_build_array(
|
||||
*(
|
||||
jsonb_sql(v, scalars_to_jsonb=False)
|
||||
for v in value
|
||||
)
|
||||
)
|
||||
elif isinstance(value, (int, float)):
|
||||
return scalar_convert(str(value))
|
||||
return value
|
||||
|
||||
|
||||
def has_index(column_or_constraint):
|
||||
"""
|
||||
Return whether or not given column or the columns of given foreign key
|
||||
constraint have an index. A column has an index if it has a single column
|
||||
index or it is the first column in compound column index.
|
||||
|
||||
A foreign key constraint has an index if the constraint columns are the
|
||||
first columns in compound column index.
|
||||
|
||||
:param column_or_constraint:
|
||||
SQLAlchemy Column object or SA ForeignKeyConstraint object
|
||||
|
||||
.. versionadded: 0.26.2
|
||||
|
||||
.. versionchanged: 0.30.18
|
||||
Added support for foreign key constaints.
|
||||
|
||||
::
|
||||
|
||||
from sqlalchemy_utils import has_index
|
||||
|
||||
|
||||
class Article(Base):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
title = sa.Column(sa.String(100))
|
||||
is_published = sa.Column(sa.Boolean, index=True)
|
||||
is_deleted = sa.Column(sa.Boolean)
|
||||
is_archived = sa.Column(sa.Boolean)
|
||||
|
||||
__table_args__ = (
|
||||
sa.Index('my_index', is_deleted, is_archived),
|
||||
)
|
||||
|
||||
|
||||
table = Article.__table__
|
||||
|
||||
has_index(table.c.is_published) # True
|
||||
has_index(table.c.is_deleted) # True
|
||||
has_index(table.c.is_archived) # False
|
||||
|
||||
|
||||
Also supports primary key indexes
|
||||
|
||||
::
|
||||
|
||||
from sqlalchemy_utils import has_index
|
||||
|
||||
|
||||
class ArticleTranslation(Base):
|
||||
__tablename__ = 'article_translation'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
locale = sa.Column(sa.String(10), primary_key=True)
|
||||
title = sa.Column(sa.String(100))
|
||||
|
||||
|
||||
table = ArticleTranslation.__table__
|
||||
|
||||
has_index(table.c.locale) # False
|
||||
has_index(table.c.id) # True
|
||||
|
||||
|
||||
This function supports foreign key constraints as well
|
||||
|
||||
::
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = 'user'
|
||||
first_name = sa.Column(sa.Unicode(255), primary_key=True)
|
||||
last_name = sa.Column(sa.Unicode(255), primary_key=True)
|
||||
|
||||
class Article(Base):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
author_first_name = sa.Column(sa.Unicode(255))
|
||||
author_last_name = sa.Column(sa.Unicode(255))
|
||||
__table_args__ = (
|
||||
sa.ForeignKeyConstraint(
|
||||
[author_first_name, author_last_name],
|
||||
[User.first_name, User.last_name]
|
||||
),
|
||||
sa.Index(
|
||||
'my_index',
|
||||
author_first_name,
|
||||
author_last_name
|
||||
)
|
||||
)
|
||||
|
||||
table = Article.__table__
|
||||
constraint = list(table.foreign_keys)[0].constraint
|
||||
|
||||
has_index(constraint) # True
|
||||
"""
|
||||
table = column_or_constraint.table
|
||||
if not isinstance(table, sa.Table):
|
||||
raise TypeError(
|
||||
'Only columns belonging to Table objects are supported. Given '
|
||||
'column belongs to %r.' % table
|
||||
)
|
||||
primary_keys = table.primary_key.columns.values()
|
||||
if isinstance(column_or_constraint, sa.ForeignKeyConstraint):
|
||||
columns = list(column_or_constraint.columns.values())
|
||||
else:
|
||||
columns = [column_or_constraint]
|
||||
|
||||
return (
|
||||
(primary_keys and starts_with(primary_keys, columns)) or
|
||||
any(
|
||||
starts_with(index.columns.values(), columns)
|
||||
for index in table.indexes
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def has_unique_index(column_or_constraint):
|
||||
"""
|
||||
Return whether or not given column or given foreign key constraint has a
|
||||
unique index.
|
||||
|
||||
A column has a unique index if it has a single column primary key index or
|
||||
it has a single column UniqueConstraint.
|
||||
|
||||
A foreign key constraint has a unique index if the columns of the
|
||||
constraint are the same as the columns of table primary key or the coluns
|
||||
of any unique index or any unique constraint of the given table.
|
||||
|
||||
:param column: SQLAlchemy Column object
|
||||
|
||||
.. versionadded: 0.27.1
|
||||
|
||||
.. versionchanged: 0.30.18
|
||||
Added support for foreign key constaints.
|
||||
|
||||
Fixed support for unique indexes (previously only worked for unique
|
||||
constraints)
|
||||
|
||||
::
|
||||
|
||||
from sqlalchemy_utils import has_unique_index
|
||||
|
||||
|
||||
class Article(Base):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
title = sa.Column(sa.String(100))
|
||||
is_published = sa.Column(sa.Boolean, unique=True)
|
||||
is_deleted = sa.Column(sa.Boolean)
|
||||
is_archived = sa.Column(sa.Boolean)
|
||||
|
||||
|
||||
table = Article.__table__
|
||||
|
||||
has_unique_index(table.c.is_published) # True
|
||||
has_unique_index(table.c.is_deleted) # False
|
||||
has_unique_index(table.c.id) # True
|
||||
|
||||
|
||||
This function supports foreign key constraints as well
|
||||
|
||||
::
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = 'user'
|
||||
first_name = sa.Column(sa.Unicode(255), primary_key=True)
|
||||
last_name = sa.Column(sa.Unicode(255), primary_key=True)
|
||||
|
||||
class Article(Base):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
author_first_name = sa.Column(sa.Unicode(255))
|
||||
author_last_name = sa.Column(sa.Unicode(255))
|
||||
__table_args__ = (
|
||||
sa.ForeignKeyConstraint(
|
||||
[author_first_name, author_last_name],
|
||||
[User.first_name, User.last_name]
|
||||
),
|
||||
sa.Index(
|
||||
'my_index',
|
||||
author_first_name,
|
||||
author_last_name,
|
||||
unique=True
|
||||
)
|
||||
)
|
||||
|
||||
table = Article.__table__
|
||||
constraint = list(table.foreign_keys)[0].constraint
|
||||
|
||||
has_unique_index(constraint) # True
|
||||
|
||||
|
||||
:raises TypeError: if given column does not belong to a Table object
|
||||
"""
|
||||
table = column_or_constraint.table
|
||||
if not isinstance(table, sa.Table):
|
||||
raise TypeError(
|
||||
'Only columns belonging to Table objects are supported. Given '
|
||||
'column belongs to %r.' % table
|
||||
)
|
||||
primary_keys = list(table.primary_key.columns.values())
|
||||
if isinstance(column_or_constraint, sa.ForeignKeyConstraint):
|
||||
columns = list(column_or_constraint.columns.values())
|
||||
else:
|
||||
columns = [column_or_constraint]
|
||||
|
||||
return (
|
||||
(columns == primary_keys) or
|
||||
any(
|
||||
columns == list(constraint.columns.values())
|
||||
for constraint in table.constraints
|
||||
if isinstance(constraint, sa.sql.schema.UniqueConstraint)
|
||||
) or
|
||||
any(
|
||||
columns == list(index.columns.values())
|
||||
for index in table.indexes
|
||||
if index.unique
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def is_auto_assigned_date_column(column):
|
||||
"""
|
||||
Returns whether or not given SQLAlchemy Column object's is auto assigned
|
||||
DateTime or Date.
|
||||
|
||||
:param column: SQLAlchemy Column object
|
||||
"""
|
||||
return (
|
||||
(
|
||||
isinstance(column.type, sa.DateTime) or
|
||||
isinstance(column.type, sa.Date)
|
||||
) and
|
||||
(
|
||||
column.default or
|
||||
column.server_default or
|
||||
column.onupdate or
|
||||
column.server_onupdate
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _set_url_database(url: sa.engine.url.URL, database):
|
||||
"""Set the database of an engine URL.
|
||||
|
||||
:param url: A SQLAlchemy engine URL.
|
||||
:param database: New database to set.
|
||||
|
||||
"""
|
||||
if hasattr(url, '_replace'):
|
||||
# Cannot use URL.set() as database may need to be set to None.
|
||||
ret = url._replace(database=database)
|
||||
else: # SQLAlchemy <1.4
|
||||
url = copy(url)
|
||||
url.database = database
|
||||
ret = url
|
||||
assert ret.database == database, ret
|
||||
return ret
|
||||
|
||||
|
||||
def _get_scalar_result(engine, sql):
|
||||
with engine.connect() as conn:
|
||||
return conn.scalar(sql)
|
||||
|
||||
|
||||
def _sqlite_file_exists(database):
|
||||
if not os.path.isfile(database) or os.path.getsize(database) < 100:
|
||||
return False
|
||||
|
||||
with open(database, 'rb') as f:
|
||||
header = f.read(100)
|
||||
|
||||
return header[:16] == b'SQLite format 3\x00'
|
||||
|
||||
|
||||
def database_exists(url):
|
||||
"""Check if a database exists.
|
||||
|
||||
:param url: A SQLAlchemy engine URL.
|
||||
|
||||
Performs backend-specific testing to quickly determine if a database
|
||||
exists on the server. ::
|
||||
|
||||
database_exists('postgresql://postgres@localhost/name') #=> False
|
||||
create_database('postgresql://postgres@localhost/name')
|
||||
database_exists('postgresql://postgres@localhost/name') #=> True
|
||||
|
||||
Supports checking against a constructed URL as well. ::
|
||||
|
||||
engine = create_engine('postgresql://postgres@localhost/name')
|
||||
database_exists(engine.url) #=> False
|
||||
create_database(engine.url)
|
||||
database_exists(engine.url) #=> True
|
||||
|
||||
"""
|
||||
|
||||
url = make_url(url)
|
||||
database = url.database
|
||||
dialect_name = url.get_dialect().name
|
||||
engine = None
|
||||
try:
|
||||
if dialect_name == 'postgresql':
|
||||
text = "SELECT 1 FROM pg_database WHERE datname='%s'" % database
|
||||
for db in (database, 'postgres', 'template1', 'template0', None):
|
||||
url = _set_url_database(url, database=db)
|
||||
engine = sa.create_engine(url)
|
||||
try:
|
||||
return bool(_get_scalar_result(engine, sa.text(text)))
|
||||
except (ProgrammingError, OperationalError):
|
||||
pass
|
||||
return False
|
||||
|
||||
elif dialect_name == 'mysql':
|
||||
url = _set_url_database(url, database=None)
|
||||
engine = sa.create_engine(url)
|
||||
text = ("SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA "
|
||||
"WHERE SCHEMA_NAME = '%s'" % database)
|
||||
return bool(_get_scalar_result(engine, sa.text(text)))
|
||||
|
||||
elif dialect_name == 'sqlite':
|
||||
url = _set_url_database(url, database=None)
|
||||
engine = sa.create_engine(url)
|
||||
if database:
|
||||
return database == ':memory:' or _sqlite_file_exists(database)
|
||||
else:
|
||||
# The default SQLAlchemy database is in memory, and :memory: is
|
||||
# not required, thus we should support that use case.
|
||||
return True
|
||||
else:
|
||||
text = 'SELECT 1'
|
||||
try:
|
||||
engine = sa.create_engine(url)
|
||||
return bool(_get_scalar_result(engine, sa.text(text)))
|
||||
except (ProgrammingError, OperationalError):
|
||||
return False
|
||||
finally:
|
||||
if engine:
|
||||
engine.dispose()
|
||||
|
||||
|
||||
def create_database(url, encoding='utf8', template=None):
|
||||
"""Issue the appropriate CREATE DATABASE statement.
|
||||
|
||||
:param url: A SQLAlchemy engine URL.
|
||||
:param encoding: The encoding to create the database as.
|
||||
:param template:
|
||||
The name of the template from which to create the new database. At the
|
||||
moment only supported by PostgreSQL driver.
|
||||
|
||||
To create a database, you can pass a simple URL that would have
|
||||
been passed to ``create_engine``. ::
|
||||
|
||||
create_database('postgresql://postgres@localhost/name')
|
||||
|
||||
You may also pass the url from an existing engine. ::
|
||||
|
||||
create_database(engine.url)
|
||||
|
||||
Has full support for mysql, postgres, and sqlite. In theory,
|
||||
other database engines should be supported.
|
||||
"""
|
||||
|
||||
url = make_url(url)
|
||||
database = url.database
|
||||
dialect_name = url.get_dialect().name
|
||||
dialect_driver = url.get_dialect().driver
|
||||
|
||||
if dialect_name == 'postgresql':
|
||||
url = _set_url_database(url, database="postgres")
|
||||
elif dialect_name == 'mssql':
|
||||
url = _set_url_database(url, database="master")
|
||||
elif dialect_name == 'cockroachdb':
|
||||
url = _set_url_database(url, database="defaultdb")
|
||||
elif not dialect_name == 'sqlite':
|
||||
url = _set_url_database(url, database=None)
|
||||
|
||||
if (dialect_name == 'mssql' and dialect_driver in {'pymssql', 'pyodbc'}) \
|
||||
or (dialect_name == 'postgresql' and dialect_driver in {
|
||||
'asyncpg', 'pg8000', 'psycopg', 'psycopg2', 'psycopg2cffi'}):
|
||||
engine = sa.create_engine(url, isolation_level='AUTOCOMMIT')
|
||||
else:
|
||||
engine = sa.create_engine(url)
|
||||
|
||||
if dialect_name == 'postgresql':
|
||||
if not template:
|
||||
template = 'template1'
|
||||
|
||||
with engine.begin() as conn:
|
||||
text = "CREATE DATABASE {} ENCODING '{}' TEMPLATE {}".format(
|
||||
quote(conn, database),
|
||||
encoding,
|
||||
quote(conn, template)
|
||||
)
|
||||
conn.execute(sa.text(text))
|
||||
|
||||
elif dialect_name == 'mysql':
|
||||
with engine.begin() as conn:
|
||||
text = "CREATE DATABASE {} CHARACTER SET = '{}'".format(
|
||||
quote(conn, database),
|
||||
encoding
|
||||
)
|
||||
conn.execute(sa.text(text))
|
||||
|
||||
elif dialect_name == 'sqlite' and database != ':memory:':
|
||||
if database:
|
||||
with engine.begin() as conn:
|
||||
conn.execute(sa.text('CREATE TABLE DB(id int)'))
|
||||
conn.execute(sa.text('DROP TABLE DB'))
|
||||
|
||||
else:
|
||||
with engine.begin() as conn:
|
||||
text = f'CREATE DATABASE {quote(conn, database)}'
|
||||
conn.execute(sa.text(text))
|
||||
|
||||
engine.dispose()
|
||||
|
||||
|
||||
def drop_database(url):
|
||||
"""Issue the appropriate DROP DATABASE statement.
|
||||
|
||||
:param url: A SQLAlchemy engine URL.
|
||||
|
||||
Works similar to the :ref:`create_database` method in that both url text
|
||||
and a constructed url are accepted. ::
|
||||
|
||||
drop_database('postgresql://postgres@localhost/name')
|
||||
drop_database(engine.url)
|
||||
|
||||
"""
|
||||
|
||||
url = make_url(url)
|
||||
database = url.database
|
||||
dialect_name = url.get_dialect().name
|
||||
dialect_driver = url.get_dialect().driver
|
||||
|
||||
if dialect_name == 'postgresql':
|
||||
url = _set_url_database(url, database="postgres")
|
||||
elif dialect_name == 'mssql':
|
||||
url = _set_url_database(url, database="master")
|
||||
elif dialect_name == 'cockroachdb':
|
||||
url = _set_url_database(url, database="defaultdb")
|
||||
elif not dialect_name == 'sqlite':
|
||||
url = _set_url_database(url, database=None)
|
||||
|
||||
if dialect_name == 'mssql' and dialect_driver in {'pymssql', 'pyodbc'}:
|
||||
engine = sa.create_engine(url, connect_args={'autocommit': True})
|
||||
elif dialect_name == 'postgresql' and dialect_driver in {
|
||||
'asyncpg', 'pg8000', 'psycopg', 'psycopg2', 'psycopg2cffi'}:
|
||||
engine = sa.create_engine(url, isolation_level='AUTOCOMMIT')
|
||||
else:
|
||||
engine = sa.create_engine(url)
|
||||
|
||||
if dialect_name == 'sqlite' and database != ':memory:':
|
||||
if database:
|
||||
os.remove(database)
|
||||
elif dialect_name == 'postgresql':
|
||||
with engine.begin() as conn:
|
||||
# Disconnect all users from the database we are dropping.
|
||||
version = conn.dialect.server_version_info
|
||||
pid_column = (
|
||||
'pid' if (version >= (9, 2)) else 'procpid'
|
||||
)
|
||||
text = '''
|
||||
SELECT pg_terminate_backend(pg_stat_activity.{pid_column})
|
||||
FROM pg_stat_activity
|
||||
WHERE pg_stat_activity.datname = '{database}'
|
||||
AND {pid_column} <> pg_backend_pid();
|
||||
'''.format(pid_column=pid_column, database=database)
|
||||
conn.execute(sa.text(text))
|
||||
|
||||
# Drop the database.
|
||||
text = f'DROP DATABASE {quote(conn, database)}'
|
||||
conn.execute(sa.text(text))
|
||||
else:
|
||||
with engine.begin() as conn:
|
||||
text = f'DROP DATABASE {quote(conn, database)}'
|
||||
conn.execute(sa.text(text))
|
||||
|
||||
engine.dispose()
|
|
@ -0,0 +1,350 @@
|
|||
from collections import defaultdict
|
||||
from itertools import groupby
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.exc import NoInspectionAvailable
|
||||
from sqlalchemy.orm import object_session
|
||||
from sqlalchemy.schema import ForeignKeyConstraint, MetaData, Table
|
||||
|
||||
from ..query_chain import QueryChain
|
||||
from .database import has_index
|
||||
from .orm import _get_class_registry, get_column_key, get_mapper, get_tables
|
||||
|
||||
|
||||
def get_foreign_key_values(fk, obj):
|
||||
mapper = get_mapper(obj)
|
||||
return {
|
||||
fk.constraint.columns.values()[index]:
|
||||
getattr(obj, element.column.key)
|
||||
if hasattr(obj, element.column.key)
|
||||
else getattr(
|
||||
obj, mapper.get_property_by_column(element.column).key
|
||||
)
|
||||
for index, element in enumerate(fk.constraint.elements)
|
||||
}
|
||||
|
||||
|
||||
def group_foreign_keys(foreign_keys):
|
||||
"""
|
||||
Return a groupby iterator that groups given foreign keys by table.
|
||||
|
||||
:param foreign_keys: a sequence of foreign keys
|
||||
|
||||
|
||||
::
|
||||
|
||||
foreign_keys = get_referencing_foreign_keys(User)
|
||||
|
||||
for table, fks in group_foreign_keys(foreign_keys):
|
||||
# do something
|
||||
pass
|
||||
|
||||
|
||||
.. seealso:: :func:`get_referencing_foreign_keys`
|
||||
|
||||
.. versionadded: 0.26.1
|
||||
"""
|
||||
foreign_keys = sorted(
|
||||
foreign_keys, key=lambda key: key.constraint.table.name
|
||||
)
|
||||
return groupby(foreign_keys, lambda key: key.constraint.table)
|
||||
|
||||
|
||||
def get_referencing_foreign_keys(mixed):
|
||||
"""
|
||||
Returns referencing foreign keys for given Table object or declarative
|
||||
class.
|
||||
|
||||
:param mixed:
|
||||
SA Table object or SA declarative class
|
||||
|
||||
::
|
||||
|
||||
get_referencing_foreign_keys(User) # set([ForeignKey('user.id')])
|
||||
|
||||
get_referencing_foreign_keys(User.__table__)
|
||||
|
||||
|
||||
This function also understands inheritance. This means it returns
|
||||
all foreign keys that reference any table in the class inheritance tree.
|
||||
|
||||
Let's say you have three classes which use joined table inheritance,
|
||||
namely TextItem, Article and BlogPost with Article and BlogPost inheriting
|
||||
TextItem.
|
||||
|
||||
::
|
||||
|
||||
# This will check all foreign keys that reference either article table
|
||||
# or textitem table.
|
||||
get_referencing_foreign_keys(Article)
|
||||
|
||||
.. seealso:: :func:`get_tables`
|
||||
"""
|
||||
if isinstance(mixed, sa.Table):
|
||||
tables = [mixed]
|
||||
else:
|
||||
tables = get_tables(mixed)
|
||||
|
||||
referencing_foreign_keys = set()
|
||||
|
||||
for table in mixed.metadata.tables.values():
|
||||
if table not in tables:
|
||||
for constraint in table.constraints:
|
||||
if isinstance(constraint, sa.sql.schema.ForeignKeyConstraint):
|
||||
for fk in constraint.elements:
|
||||
if any(fk.references(t) for t in tables):
|
||||
referencing_foreign_keys.add(fk)
|
||||
return referencing_foreign_keys
|
||||
|
||||
|
||||
def merge_references(from_, to, foreign_keys=None):
|
||||
"""
|
||||
Merge the references of an entity into another entity.
|
||||
|
||||
Consider the following models::
|
||||
|
||||
class User(self.Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.String(255))
|
||||
|
||||
def __repr__(self):
|
||||
return 'User(name=%r)' % self.name
|
||||
|
||||
class BlogPost(self.Base):
|
||||
__tablename__ = 'blog_post'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
title = sa.Column(sa.String(255))
|
||||
author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))
|
||||
|
||||
author = sa.orm.relationship(User)
|
||||
|
||||
|
||||
Now lets add some data::
|
||||
|
||||
john = self.User(name='John')
|
||||
jack = self.User(name='Jack')
|
||||
post = self.BlogPost(title='Some title', author=john)
|
||||
post2 = self.BlogPost(title='Other title', author=jack)
|
||||
self.session.add_all([
|
||||
john,
|
||||
jack,
|
||||
post,
|
||||
post2
|
||||
])
|
||||
self.session.commit()
|
||||
|
||||
|
||||
If we wanted to merge all John's references to Jack it would be as easy as
|
||||
::
|
||||
|
||||
merge_references(john, jack)
|
||||
self.session.commit()
|
||||
|
||||
post.author # User(name='Jack')
|
||||
post2.author # User(name='Jack')
|
||||
|
||||
|
||||
|
||||
:param from_: an entity to merge into another entity
|
||||
:param to: an entity to merge another entity into
|
||||
:param foreign_keys: A sequence of foreign keys. By default this is None
|
||||
indicating all referencing foreign keys should be used.
|
||||
|
||||
.. seealso: :func:`dependent_objects`
|
||||
|
||||
.. versionadded: 0.26.1
|
||||
|
||||
.. versionchanged: 0.40.0
|
||||
|
||||
Removed possibility for old-style synchronize_session merging. Only
|
||||
SQL based merging supported for now.
|
||||
"""
|
||||
if from_.__tablename__ != to.__tablename__:
|
||||
raise TypeError('The tables of given arguments do not match.')
|
||||
|
||||
session = object_session(from_)
|
||||
foreign_keys = get_referencing_foreign_keys(from_)
|
||||
|
||||
for fk in foreign_keys:
|
||||
old_values = get_foreign_key_values(fk, from_)
|
||||
new_values = get_foreign_key_values(fk, to)
|
||||
criteria = (
|
||||
getattr(fk.constraint.table.c, key.key) == value
|
||||
for key, value in old_values.items()
|
||||
)
|
||||
query = (
|
||||
fk.constraint.table.update()
|
||||
.where(sa.and_(*criteria))
|
||||
.values(
|
||||
{key.key: value for key, value in new_values.items()}
|
||||
)
|
||||
)
|
||||
session.execute(query)
|
||||
|
||||
|
||||
def dependent_objects(obj, foreign_keys=None):
|
||||
"""
|
||||
Return a :class:`~sqlalchemy_utils.query_chain.QueryChain` that iterates
|
||||
through all dependent objects for given SQLAlchemy object.
|
||||
|
||||
Consider a User object is referenced in various articles and also in
|
||||
various orders. Getting all these dependent objects is as easy as::
|
||||
|
||||
from sqlalchemy_utils import dependent_objects
|
||||
|
||||
|
||||
dependent_objects(user)
|
||||
|
||||
|
||||
If you expect an object to have lots of dependent_objects it might be good
|
||||
to limit the results::
|
||||
|
||||
|
||||
dependent_objects(user).limit(5)
|
||||
|
||||
|
||||
|
||||
The common use case is checking for all restrict dependent objects before
|
||||
deleting parent object and inform the user if there are dependent objects
|
||||
with ondelete='RESTRICT' foreign keys. If this kind of checking is not used
|
||||
it will lead to nasty IntegrityErrors being raised.
|
||||
|
||||
In the following example we delete given user if it doesn't have any
|
||||
foreign key restricted dependent objects::
|
||||
|
||||
|
||||
from sqlalchemy_utils import get_referencing_foreign_keys
|
||||
|
||||
|
||||
user = session.query(User).get(some_user_id)
|
||||
|
||||
|
||||
deps = list(
|
||||
dependent_objects(
|
||||
user,
|
||||
(
|
||||
fk for fk in get_referencing_foreign_keys(User)
|
||||
# On most databases RESTRICT is the default mode hence we
|
||||
# check for None values also
|
||||
if fk.ondelete == 'RESTRICT' or fk.ondelete is None
|
||||
)
|
||||
).limit(5)
|
||||
)
|
||||
|
||||
if deps:
|
||||
# Do something to inform the user
|
||||
pass
|
||||
else:
|
||||
session.delete(user)
|
||||
|
||||
|
||||
:param obj: SQLAlchemy declarative model object
|
||||
:param foreign_keys:
|
||||
A sequence of foreign keys to use for searching the dependent_objects
|
||||
for given object. By default this is None, indicating that all foreign
|
||||
keys referencing the object will be used.
|
||||
|
||||
.. note::
|
||||
This function does not support exotic mappers that use multiple tables
|
||||
|
||||
.. seealso:: :func:`get_referencing_foreign_keys`
|
||||
.. seealso:: :func:`merge_references`
|
||||
|
||||
.. versionadded: 0.26.0
|
||||
"""
|
||||
if foreign_keys is None:
|
||||
foreign_keys = get_referencing_foreign_keys(obj)
|
||||
|
||||
session = object_session(obj)
|
||||
|
||||
chain = QueryChain([])
|
||||
classes = _get_class_registry(obj.__class__)
|
||||
|
||||
for table, keys in group_foreign_keys(foreign_keys):
|
||||
keys = list(keys)
|
||||
for class_ in classes.values():
|
||||
try:
|
||||
mapper = sa.inspect(class_)
|
||||
except NoInspectionAvailable:
|
||||
continue
|
||||
parent_mapper = mapper.inherits
|
||||
if (
|
||||
table in mapper.tables and
|
||||
not (parent_mapper and table in parent_mapper.tables)
|
||||
):
|
||||
query = session.query(class_).filter(
|
||||
sa.or_(*_get_criteria(keys, class_, obj))
|
||||
)
|
||||
chain.queries.append(query)
|
||||
return chain
|
||||
|
||||
|
||||
def _get_criteria(keys, class_, obj):
|
||||
criteria = []
|
||||
visited_constraints = []
|
||||
for key in keys:
|
||||
if key.constraint in visited_constraints:
|
||||
continue
|
||||
visited_constraints.append(key.constraint)
|
||||
|
||||
subcriteria = []
|
||||
for index, column in enumerate(key.constraint.columns):
|
||||
foreign_column = (
|
||||
key.constraint.elements[index].column
|
||||
)
|
||||
subcriteria.append(
|
||||
getattr(class_, get_column_key(class_, column)) ==
|
||||
getattr(
|
||||
obj,
|
||||
sa.inspect(type(obj))
|
||||
.get_property_by_column(
|
||||
foreign_column
|
||||
).key
|
||||
)
|
||||
)
|
||||
criteria.append(sa.and_(*subcriteria))
|
||||
return criteria
|
||||
|
||||
|
||||
def non_indexed_foreign_keys(metadata, engine=None):
|
||||
"""
|
||||
Finds all non indexed foreign keys from all tables of given MetaData.
|
||||
|
||||
Very useful for optimizing postgresql database and finding out which
|
||||
foreign keys need indexes.
|
||||
|
||||
:param metadata: MetaData object to inspect tables from
|
||||
"""
|
||||
reflected_metadata = MetaData()
|
||||
|
||||
bind = getattr(metadata, 'bind', None)
|
||||
if bind is None and engine is None:
|
||||
raise Exception(
|
||||
'Either pass a metadata object with bind or '
|
||||
'pass engine as a second parameter'
|
||||
)
|
||||
|
||||
constraints = defaultdict(list)
|
||||
|
||||
for table_name in metadata.tables.keys():
|
||||
table = Table(
|
||||
table_name,
|
||||
reflected_metadata,
|
||||
autoload_with=bind or engine
|
||||
)
|
||||
|
||||
for constraint in table.constraints:
|
||||
if not isinstance(constraint, ForeignKeyConstraint):
|
||||
continue
|
||||
|
||||
if not has_index(constraint):
|
||||
constraints[table.name].append(constraint)
|
||||
|
||||
return dict(constraints)
|
||||
|
||||
|
||||
def get_fk_constraint_for_columns(table, *columns):
|
||||
for constraint in table.constraints:
|
||||
if list(constraint.columns.values()) == list(columns):
|
||||
return constraint
|
|
@ -0,0 +1,112 @@
|
|||
import contextlib
|
||||
import datetime
|
||||
import inspect
|
||||
import io
|
||||
import re
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
def create_mock_engine(bind, stream=None):
|
||||
"""Create a mock SQLAlchemy engine from the passed engine or bind URL.
|
||||
|
||||
:param bind: A SQLAlchemy engine or bind URL to mock.
|
||||
:param stream: Render all DDL operations to the stream.
|
||||
"""
|
||||
|
||||
if not isinstance(bind, str):
|
||||
bind_url = str(bind.url)
|
||||
|
||||
else:
|
||||
bind_url = bind
|
||||
|
||||
if stream is not None:
|
||||
|
||||
def dump(sql, *args, **kwargs):
|
||||
|
||||
class Compiler(type(sql._compiler(engine.dialect))):
|
||||
|
||||
def visit_bindparam(self, bindparam, *args, **kwargs):
|
||||
return self.render_literal_value(
|
||||
bindparam.value, bindparam.type)
|
||||
|
||||
def render_literal_value(self, value, type_):
|
||||
if isinstance(value, int):
|
||||
return str(value)
|
||||
|
||||
elif isinstance(value, (datetime.date, datetime.datetime)):
|
||||
return "'%s'" % value
|
||||
|
||||
return super().render_literal_value(
|
||||
value, type_)
|
||||
|
||||
text = str(Compiler(engine.dialect, sql).process(sql))
|
||||
text = re.sub(r'\n+', '\n', text)
|
||||
text = text.strip('\n').strip()
|
||||
|
||||
stream.write('\n%s;' % text)
|
||||
|
||||
else:
|
||||
def dump(*args, **kw):
|
||||
return None
|
||||
|
||||
try:
|
||||
engine = sa.create_mock_engine(bind_url, executor=dump)
|
||||
except AttributeError: # SQLAlchemy <1.4
|
||||
engine = sa.create_engine(bind_url, strategy='mock', executor=dump)
|
||||
return engine
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def mock_engine(engine, stream=None):
|
||||
"""Mocks out the engine specified in the passed bind expression.
|
||||
|
||||
Note this function is meant for convenience and protected usage. Do NOT
|
||||
blindly pass user input to this function as it uses exec.
|
||||
|
||||
:param engine: A python expression that represents the engine to mock.
|
||||
:param stream: Render all DDL operations to the stream.
|
||||
"""
|
||||
|
||||
# Create a stream if not present.
|
||||
|
||||
if stream is None:
|
||||
stream = io.StringIO()
|
||||
|
||||
# Navigate the stack and find the calling frame that allows the
|
||||
# expression to execute.
|
||||
|
||||
for frame in inspect.stack()[1:]:
|
||||
|
||||
try:
|
||||
frame = frame[0]
|
||||
expression = '__target = %s' % engine
|
||||
exec(expression, frame.f_globals, frame.f_locals)
|
||||
target = frame.f_locals['__target']
|
||||
break
|
||||
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
else:
|
||||
|
||||
raise ValueError('Not a valid python expression', engine)
|
||||
|
||||
# Evaluate the expression and get the target engine.
|
||||
|
||||
frame.f_locals['__mock'] = create_mock_engine(target, stream)
|
||||
|
||||
# Replace the target with our mock.
|
||||
|
||||
exec('%s = __mock' % engine, frame.f_globals, frame.f_locals)
|
||||
|
||||
# Give control back.
|
||||
|
||||
yield stream
|
||||
|
||||
# Put the target engine back.
|
||||
|
||||
frame.f_locals['__target'] = target
|
||||
exec('%s = __target' % engine, frame.f_globals, frame.f_locals)
|
||||
exec('del __target', frame.f_globals, frame.f_locals)
|
||||
exec('del __mock', frame.f_globals, frame.f_locals)
|
|
@ -0,0 +1,904 @@
|
|||
from collections import OrderedDict
|
||||
from functools import partial
|
||||
from inspect import isclass
|
||||
from operator import attrgetter
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.engine.interfaces import Dialect
|
||||
from sqlalchemy.ext.hybrid import hybrid_property
|
||||
from sqlalchemy.orm import ColumnProperty, mapperlib, RelationshipProperty
|
||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||
from sqlalchemy.orm.exc import UnmappedInstanceError
|
||||
|
||||
try:
|
||||
from sqlalchemy.orm.context import _ColumnEntity, _MapperEntity
|
||||
except ImportError: # SQLAlchemy <1.4
|
||||
from sqlalchemy.orm.query import _ColumnEntity, _MapperEntity
|
||||
|
||||
from sqlalchemy.orm.session import object_session
|
||||
from sqlalchemy.orm.util import AliasedInsp
|
||||
|
||||
from ..utils import is_sequence
|
||||
|
||||
|
||||
def get_class_by_table(base, table, data=None):
|
||||
"""
|
||||
Return declarative class associated with given table. If no class is found
|
||||
this function returns `None`. If multiple classes were found (polymorphic
|
||||
cases) additional `data` parameter can be given to hint which class
|
||||
to return.
|
||||
|
||||
::
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = 'entity'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.String)
|
||||
|
||||
|
||||
get_class_by_table(Base, User.__table__) # User class
|
||||
|
||||
|
||||
This function also supports models using single table inheritance.
|
||||
Additional data paratemer should be provided in these case.
|
||||
|
||||
::
|
||||
|
||||
class Entity(Base):
|
||||
__tablename__ = 'entity'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.String)
|
||||
type = sa.Column(sa.String)
|
||||
__mapper_args__ = {
|
||||
'polymorphic_on': type,
|
||||
'polymorphic_identity': 'entity'
|
||||
}
|
||||
|
||||
class User(Entity):
|
||||
__mapper_args__ = {
|
||||
'polymorphic_identity': 'user'
|
||||
}
|
||||
|
||||
|
||||
# Entity class
|
||||
get_class_by_table(Base, Entity.__table__, {'type': 'entity'})
|
||||
|
||||
# User class
|
||||
get_class_by_table(Base, Entity.__table__, {'type': 'user'})
|
||||
|
||||
|
||||
:param base: Declarative model base
|
||||
:param table: SQLAlchemy Table object
|
||||
:param data: Data row to determine the class in polymorphic scenarios
|
||||
:return: Declarative class or None.
|
||||
"""
|
||||
found_classes = {
|
||||
c for c in _get_class_registry(base).values()
|
||||
if hasattr(c, '__table__') and c.__table__ is table
|
||||
}
|
||||
if len(found_classes) > 1:
|
||||
if not data:
|
||||
raise ValueError(
|
||||
"Multiple declarative classes found for table '{}'. "
|
||||
"Please provide data parameter for this function to be able "
|
||||
"to determine polymorphic scenarios.".format(
|
||||
table.name
|
||||
)
|
||||
)
|
||||
else:
|
||||
for cls in found_classes:
|
||||
mapper = sa.inspect(cls)
|
||||
polymorphic_on = mapper.polymorphic_on.name
|
||||
if polymorphic_on in data:
|
||||
if data[polymorphic_on] == mapper.polymorphic_identity:
|
||||
return cls
|
||||
raise ValueError(
|
||||
"Multiple declarative classes found for table '{}'. Given "
|
||||
"data row does not match any polymorphic identity of the "
|
||||
"found classes.".format(
|
||||
table.name
|
||||
)
|
||||
)
|
||||
elif found_classes:
|
||||
return found_classes.pop()
|
||||
return None
|
||||
|
||||
|
||||
def get_type(expr):
|
||||
"""
|
||||
Return the associated type with given Column, InstrumentedAttribute,
|
||||
ColumnProperty, RelationshipProperty or other similar SQLAlchemy construct.
|
||||
|
||||
For constructs wrapping columns this is the column type. For relationships
|
||||
this function returns the relationship mapper class.
|
||||
|
||||
:param expr:
|
||||
SQLAlchemy Column, InstrumentedAttribute, ColumnProperty or other
|
||||
similar SA construct.
|
||||
|
||||
::
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.String)
|
||||
|
||||
|
||||
class Article(Base):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
author_id = sa.Column(sa.Integer, sa.ForeignKey(User.id))
|
||||
author = sa.orm.relationship(User)
|
||||
|
||||
|
||||
get_type(User.__table__.c.name) # sa.String()
|
||||
get_type(User.name) # sa.String()
|
||||
get_type(User.name.property) # sa.String()
|
||||
|
||||
get_type(Article.author) # User
|
||||
|
||||
|
||||
.. versionadded: 0.30.9
|
||||
"""
|
||||
if hasattr(expr, 'type'):
|
||||
return expr.type
|
||||
elif isinstance(expr, InstrumentedAttribute):
|
||||
expr = expr.property
|
||||
|
||||
if isinstance(expr, ColumnProperty):
|
||||
return expr.columns[0].type
|
||||
elif isinstance(expr, RelationshipProperty):
|
||||
return expr.mapper.class_
|
||||
raise TypeError("Couldn't inspect type.")
|
||||
|
||||
|
||||
def cast_if(expression, type_):
|
||||
"""
|
||||
Produce a CAST expression but only if given expression is not of given type
|
||||
already.
|
||||
|
||||
Assume we have a model with two fields id (Integer) and name (String).
|
||||
|
||||
::
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy_utils import cast_if
|
||||
|
||||
|
||||
cast_if(User.id, sa.Integer) # "user".id
|
||||
cast_if(User.name, sa.String) # "user".name
|
||||
cast_if(User.id, sa.String) # CAST("user".id AS TEXT)
|
||||
|
||||
|
||||
This function supports scalar values as well.
|
||||
|
||||
::
|
||||
|
||||
cast_if(1, sa.Integer) # 1
|
||||
cast_if('text', sa.String) # 'text'
|
||||
cast_if(1, sa.String) # CAST(1 AS TEXT)
|
||||
|
||||
|
||||
:param expression:
|
||||
A SQL expression, such as a ColumnElement expression or a Python string
|
||||
which will be coerced into a bound literal value.
|
||||
:param type_:
|
||||
A TypeEngine class or instance indicating the type to which the CAST
|
||||
should apply.
|
||||
|
||||
.. versionadded: 0.30.14
|
||||
"""
|
||||
try:
|
||||
expr_type = get_type(expression)
|
||||
except TypeError:
|
||||
expr_type = expression
|
||||
check_type = type_().python_type
|
||||
else:
|
||||
check_type = type_
|
||||
|
||||
return (
|
||||
sa.cast(expression, type_)
|
||||
if not isinstance(expr_type, check_type)
|
||||
else expression
|
||||
)
|
||||
|
||||
|
||||
def get_column_key(model, column):
|
||||
"""
|
||||
Return the key for given column in given model.
|
||||
|
||||
:param model: SQLAlchemy declarative model object
|
||||
|
||||
::
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column('_name', sa.String)
|
||||
|
||||
|
||||
get_column_key(User, User.__table__.c._name) # 'name'
|
||||
|
||||
.. versionadded: 0.26.5
|
||||
|
||||
.. versionchanged: 0.27.11
|
||||
Throws UnmappedColumnError instead of ValueError when no property was
|
||||
found for given column. This is consistent with how SQLAlchemy works.
|
||||
"""
|
||||
mapper = sa.inspect(model)
|
||||
try:
|
||||
return mapper.get_property_by_column(column).key
|
||||
except sa.orm.exc.UnmappedColumnError:
|
||||
for key, c in mapper.columns.items():
|
||||
if c.name == column.name and c.table is column.table:
|
||||
return key
|
||||
raise sa.orm.exc.UnmappedColumnError(
|
||||
'No column %s is configured on mapper %s...' %
|
||||
(column, mapper)
|
||||
)
|
||||
|
||||
|
||||
def get_mapper(mixed):
|
||||
"""
|
||||
Return related SQLAlchemy Mapper for given SQLAlchemy object.
|
||||
|
||||
:param mixed: SQLAlchemy Table / Alias / Mapper / declarative model object
|
||||
|
||||
::
|
||||
|
||||
from sqlalchemy_utils import get_mapper
|
||||
|
||||
|
||||
get_mapper(User)
|
||||
|
||||
get_mapper(User())
|
||||
|
||||
get_mapper(User.__table__)
|
||||
|
||||
get_mapper(User.__mapper__)
|
||||
|
||||
get_mapper(sa.orm.aliased(User))
|
||||
|
||||
get_mapper(sa.orm.aliased(User.__table__))
|
||||
|
||||
|
||||
Raises:
|
||||
ValueError: if multiple mappers were found for given argument
|
||||
|
||||
.. versionadded: 0.26.1
|
||||
"""
|
||||
if isinstance(mixed, _MapperEntity):
|
||||
mixed = mixed.expr
|
||||
elif isinstance(mixed, sa.Column):
|
||||
mixed = mixed.table
|
||||
elif isinstance(mixed, _ColumnEntity):
|
||||
mixed = mixed.expr
|
||||
|
||||
if isinstance(mixed, sa.orm.Mapper):
|
||||
return mixed
|
||||
if isinstance(mixed, sa.orm.util.AliasedClass):
|
||||
return sa.inspect(mixed).mapper
|
||||
if isinstance(mixed, sa.sql.selectable.Alias):
|
||||
mixed = mixed.element
|
||||
if isinstance(mixed, AliasedInsp):
|
||||
return mixed.mapper
|
||||
if isinstance(mixed, sa.orm.attributes.InstrumentedAttribute):
|
||||
mixed = mixed.class_
|
||||
if isinstance(mixed, sa.Table):
|
||||
if hasattr(mapperlib, '_all_registries'):
|
||||
all_mappers = set()
|
||||
for mapper_registry in mapperlib._all_registries():
|
||||
all_mappers.update(mapper_registry.mappers)
|
||||
else: # SQLAlchemy <1.4
|
||||
all_mappers = mapperlib._mapper_registry
|
||||
mappers = [
|
||||
mapper for mapper in all_mappers
|
||||
if mixed in mapper.tables
|
||||
]
|
||||
if len(mappers) > 1:
|
||||
raise ValueError(
|
||||
"Multiple mappers found for table '%s'." % mixed.name
|
||||
)
|
||||
elif not mappers:
|
||||
raise ValueError(
|
||||
"Could not get mapper for table '%s'." % mixed.name
|
||||
)
|
||||
else:
|
||||
return mappers[0]
|
||||
if not isclass(mixed):
|
||||
mixed = type(mixed)
|
||||
return sa.inspect(mixed)
|
||||
|
||||
|
||||
def get_bind(obj):
|
||||
"""
|
||||
Return the bind for given SQLAlchemy Engine / Connection / declarative
|
||||
model object.
|
||||
|
||||
:param obj: SQLAlchemy Engine / Connection / declarative model object
|
||||
|
||||
::
|
||||
|
||||
from sqlalchemy_utils import get_bind
|
||||
|
||||
|
||||
get_bind(session) # Connection object
|
||||
|
||||
get_bind(user)
|
||||
|
||||
"""
|
||||
if hasattr(obj, 'bind'):
|
||||
conn = obj.bind
|
||||
else:
|
||||
try:
|
||||
conn = object_session(obj).bind
|
||||
except UnmappedInstanceError:
|
||||
conn = obj
|
||||
|
||||
if not hasattr(conn, 'execute'):
|
||||
raise TypeError(
|
||||
'This method accepts only Session, Engine, Connection and '
|
||||
'declarative model objects.'
|
||||
)
|
||||
return conn
|
||||
|
||||
|
||||
def get_primary_keys(mixed):
|
||||
"""
|
||||
Return an OrderedDict of all primary keys for given Table object,
|
||||
declarative class or declarative class instance.
|
||||
|
||||
:param mixed:
|
||||
SA Table object, SA declarative class or SA declarative class instance
|
||||
|
||||
::
|
||||
|
||||
get_primary_keys(User)
|
||||
|
||||
get_primary_keys(User())
|
||||
|
||||
get_primary_keys(User.__table__)
|
||||
|
||||
get_primary_keys(User.__mapper__)
|
||||
|
||||
get_primary_keys(sa.orm.aliased(User))
|
||||
|
||||
get_primary_keys(sa.orm.aliased(User.__table__))
|
||||
|
||||
|
||||
.. versionchanged: 0.25.3
|
||||
Made the function return an ordered dictionary instead of generator.
|
||||
This change was made to support primary key aliases.
|
||||
|
||||
Renamed this function to 'get_primary_keys', formerly 'primary_keys'
|
||||
|
||||
.. seealso:: :func:`get_columns`
|
||||
"""
|
||||
return OrderedDict(
|
||||
(
|
||||
(key, column) for key, column in get_columns(mixed).items()
|
||||
if column.primary_key
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def get_tables(mixed):
|
||||
"""
|
||||
Return a set of tables associated with given SQLAlchemy object.
|
||||
|
||||
Let's say we have three classes which use joined table inheritance
|
||||
TextItem, Article and BlogPost. Article and BlogPost inherit TextItem.
|
||||
|
||||
::
|
||||
|
||||
get_tables(Article) # set([Table('article', ...), Table('text_item')])
|
||||
|
||||
get_tables(Article())
|
||||
|
||||
get_tables(Article.__mapper__)
|
||||
|
||||
|
||||
If the TextItem entity is using with_polymorphic='*' then this function
|
||||
returns all child tables (article and blog_post) as well.
|
||||
|
||||
::
|
||||
|
||||
|
||||
get_tables(TextItem) # set([Table('text_item', ...)], ...])
|
||||
|
||||
|
||||
.. versionadded: 0.26.0
|
||||
|
||||
:param mixed:
|
||||
SQLAlchemy Mapper, Declarative class, Column, InstrumentedAttribute or
|
||||
a SA Alias object wrapping any of these objects.
|
||||
"""
|
||||
if isinstance(mixed, sa.Table):
|
||||
return [mixed]
|
||||
elif isinstance(mixed, sa.Column):
|
||||
return [mixed.table]
|
||||
elif isinstance(mixed, sa.orm.attributes.InstrumentedAttribute):
|
||||
return mixed.parent.tables
|
||||
elif isinstance(mixed, _ColumnEntity):
|
||||
mixed = mixed.expr
|
||||
|
||||
mapper = get_mapper(mixed)
|
||||
|
||||
polymorphic_mappers = get_polymorphic_mappers(mapper)
|
||||
if polymorphic_mappers:
|
||||
tables = sum((m.tables for m in polymorphic_mappers), [])
|
||||
else:
|
||||
tables = mapper.tables
|
||||
return tables
|
||||
|
||||
|
||||
def get_columns(mixed):
|
||||
"""
|
||||
Return a collection of all Column objects for given SQLAlchemy
|
||||
object.
|
||||
|
||||
The type of the collection depends on the type of the object to return the
|
||||
columns from.
|
||||
|
||||
::
|
||||
|
||||
get_columns(User)
|
||||
|
||||
get_columns(User())
|
||||
|
||||
get_columns(User.__table__)
|
||||
|
||||
get_columns(User.__mapper__)
|
||||
|
||||
get_columns(sa.orm.aliased(User))
|
||||
|
||||
get_columns(sa.orm.alised(User.__table__))
|
||||
|
||||
|
||||
:param mixed:
|
||||
SA Table object, SA Mapper, SA declarative class, SA declarative class
|
||||
instance or an alias of any of these objects
|
||||
"""
|
||||
if isinstance(mixed, sa.sql.selectable.Selectable):
|
||||
try:
|
||||
return mixed.selected_columns
|
||||
except AttributeError: # SQLAlchemy <1.4
|
||||
return mixed.c
|
||||
if isinstance(mixed, sa.orm.util.AliasedClass):
|
||||
return sa.inspect(mixed).mapper.columns
|
||||
if isinstance(mixed, sa.orm.Mapper):
|
||||
return mixed.columns
|
||||
if isinstance(mixed, InstrumentedAttribute):
|
||||
return mixed.property.columns
|
||||
if isinstance(mixed, ColumnProperty):
|
||||
return mixed.columns
|
||||
if isinstance(mixed, sa.Column):
|
||||
return [mixed]
|
||||
if not isclass(mixed):
|
||||
mixed = mixed.__class__
|
||||
return sa.inspect(mixed).columns
|
||||
|
||||
|
||||
def table_name(obj):
|
||||
"""
|
||||
Return table name of given target, declarative class or the
|
||||
table name where the declarative attribute is bound to.
|
||||
"""
|
||||
class_ = getattr(obj, 'class_', obj)
|
||||
|
||||
try:
|
||||
return class_.__tablename__
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
try:
|
||||
return class_.__table__.name
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
def getattrs(obj, attrs):
|
||||
return map(partial(getattr, obj), attrs)
|
||||
|
||||
|
||||
def quote(mixed, ident):
|
||||
"""
|
||||
Conditionally quote an identifier.
|
||||
::
|
||||
|
||||
|
||||
from sqlalchemy_utils import quote
|
||||
|
||||
|
||||
engine = create_engine('sqlite:///:memory:')
|
||||
|
||||
quote(engine, 'order')
|
||||
# '"order"'
|
||||
|
||||
quote(engine, 'some_other_identifier')
|
||||
# 'some_other_identifier'
|
||||
|
||||
|
||||
:param mixed: SQLAlchemy Session / Connection / Engine / Dialect object.
|
||||
:param ident: identifier to conditionally quote
|
||||
"""
|
||||
if isinstance(mixed, Dialect):
|
||||
dialect = mixed
|
||||
else:
|
||||
dialect = get_bind(mixed).dialect
|
||||
return dialect.preparer(dialect).quote(ident)
|
||||
|
||||
|
||||
def _get_query_compile_state(query):
|
||||
if hasattr(query, '_compile_state'):
|
||||
return query._compile_state()
|
||||
else: # SQLAlchemy <1.4
|
||||
return query
|
||||
|
||||
|
||||
def get_polymorphic_mappers(mixed):
|
||||
if isinstance(mixed, AliasedInsp):
|
||||
return mixed.with_polymorphic_mappers
|
||||
else:
|
||||
return mixed.polymorphic_map.values()
|
||||
|
||||
|
||||
def get_descriptor(entity, attr):
|
||||
mapper = sa.inspect(entity)
|
||||
|
||||
for key, descriptor in get_all_descriptors(mapper).items():
|
||||
if attr == key:
|
||||
prop = (
|
||||
descriptor.property
|
||||
if hasattr(descriptor, 'property')
|
||||
else None
|
||||
)
|
||||
if isinstance(prop, ColumnProperty):
|
||||
if isinstance(entity, sa.orm.util.AliasedClass):
|
||||
for c in mapper.selectable.c:
|
||||
if c.key == attr:
|
||||
return c
|
||||
else:
|
||||
# If the property belongs to a class that uses
|
||||
# polymorphic inheritance we have to take into account
|
||||
# situations where the attribute exists in child class
|
||||
# but not in parent class.
|
||||
return getattr(prop.parent.class_, attr)
|
||||
else:
|
||||
# Handle synonyms, relationship properties and hybrid
|
||||
# properties
|
||||
|
||||
if isinstance(entity, sa.orm.util.AliasedClass):
|
||||
return getattr(entity, attr)
|
||||
try:
|
||||
return getattr(mapper.class_, attr)
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
|
||||
def get_all_descriptors(expr):
|
||||
if isinstance(expr, sa.sql.selectable.Selectable):
|
||||
return expr.c
|
||||
insp = sa.inspect(expr)
|
||||
try:
|
||||
polymorphic_mappers = get_polymorphic_mappers(insp)
|
||||
except sa.exc.NoInspectionAvailable:
|
||||
return get_mapper(expr).all_orm_descriptors
|
||||
else:
|
||||
attrs = dict(get_mapper(expr).all_orm_descriptors)
|
||||
for submapper in polymorphic_mappers:
|
||||
for key, descriptor in submapper.all_orm_descriptors.items():
|
||||
if key not in attrs:
|
||||
attrs[key] = descriptor
|
||||
return attrs
|
||||
|
||||
|
||||
def get_hybrid_properties(model):
|
||||
"""
|
||||
Returns a dictionary of hybrid property keys and hybrid properties for
|
||||
given SQLAlchemy declarative model / mapper.
|
||||
|
||||
|
||||
Consider the following model
|
||||
|
||||
::
|
||||
|
||||
|
||||
from sqlalchemy.ext.hybrid import hybrid_property
|
||||
|
||||
|
||||
class Category(Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
@hybrid_property
|
||||
def lowercase_name(self):
|
||||
return self.name.lower()
|
||||
|
||||
@lowercase_name.expression
|
||||
def lowercase_name(cls):
|
||||
return sa.func.lower(cls.name)
|
||||
|
||||
|
||||
You can now easily get a list of all hybrid property names
|
||||
|
||||
::
|
||||
|
||||
|
||||
from sqlalchemy_utils import get_hybrid_properties
|
||||
|
||||
|
||||
get_hybrid_properties(Category).keys() # ['lowercase_name']
|
||||
|
||||
|
||||
This function also supports aliased classes
|
||||
|
||||
::
|
||||
|
||||
|
||||
get_hybrid_properties(
|
||||
sa.orm.aliased(Category)
|
||||
).keys() # ['lowercase_name']
|
||||
|
||||
|
||||
.. versionchanged: 0.26.7
|
||||
This function now returns a dictionary instead of generator
|
||||
|
||||
.. versionchanged: 0.30.15
|
||||
Added support for aliased classes
|
||||
|
||||
:param model: SQLAlchemy declarative model or mapper
|
||||
"""
|
||||
return {
|
||||
key: prop
|
||||
for key, prop in get_mapper(model).all_orm_descriptors.items()
|
||||
if isinstance(prop, hybrid_property)
|
||||
}
|
||||
|
||||
|
||||
def get_declarative_base(model):
|
||||
"""
|
||||
Returns the declarative base for given model class.
|
||||
|
||||
:param model: SQLAlchemy declarative model
|
||||
"""
|
||||
for parent in model.__bases__:
|
||||
try:
|
||||
parent.metadata
|
||||
return get_declarative_base(parent)
|
||||
except AttributeError:
|
||||
pass
|
||||
return model
|
||||
|
||||
|
||||
def getdotattr(obj_or_class, dot_path, condition=None):
|
||||
"""
|
||||
Allow dot-notated strings to be passed to `getattr`.
|
||||
|
||||
::
|
||||
|
||||
getdotattr(SubSection, 'section.document')
|
||||
|
||||
getdotattr(subsection, 'section.document')
|
||||
|
||||
|
||||
:param obj_or_class: Any object or class
|
||||
:param dot_path: Attribute path with dot mark as separator
|
||||
"""
|
||||
last = obj_or_class
|
||||
|
||||
for path in str(dot_path).split('.'):
|
||||
getter = attrgetter(path)
|
||||
|
||||
if is_sequence(last):
|
||||
tmp = []
|
||||
for element in last:
|
||||
value = getter(element)
|
||||
if is_sequence(value):
|
||||
tmp.extend(value)
|
||||
else:
|
||||
tmp.append(value)
|
||||
last = tmp
|
||||
elif isinstance(last, InstrumentedAttribute):
|
||||
last = getter(last.property.mapper.class_)
|
||||
elif last is None:
|
||||
return None
|
||||
else:
|
||||
last = getter(last)
|
||||
if condition is not None:
|
||||
if is_sequence(last):
|
||||
last = [v for v in last if condition(v)]
|
||||
else:
|
||||
if not condition(last):
|
||||
return None
|
||||
|
||||
return last
|
||||
|
||||
|
||||
def is_deleted(obj):
|
||||
return obj in sa.orm.object_session(obj).deleted
|
||||
|
||||
|
||||
def has_changes(obj, attrs=None, exclude=None):
|
||||
"""
|
||||
Simple shortcut function for checking if given attributes of given
|
||||
declarative model object have changed during the session. Without
|
||||
parameters this checks if given object has any modificiations. Additionally
|
||||
exclude parameter can be given to check if given object has any changes
|
||||
in any attributes other than the ones given in exclude.
|
||||
|
||||
|
||||
::
|
||||
|
||||
|
||||
from sqlalchemy_utils import has_changes
|
||||
|
||||
|
||||
user = User()
|
||||
|
||||
has_changes(user, 'name') # False
|
||||
|
||||
user.name = 'someone'
|
||||
|
||||
has_changes(user, 'name') # True
|
||||
|
||||
has_changes(user) # True
|
||||
|
||||
|
||||
You can check multiple attributes as well.
|
||||
::
|
||||
|
||||
|
||||
has_changes(user, ['age']) # True
|
||||
|
||||
has_changes(user, ['name', 'age']) # True
|
||||
|
||||
|
||||
This function also supports excluding certain attributes.
|
||||
|
||||
::
|
||||
|
||||
has_changes(user, exclude=['name']) # False
|
||||
|
||||
has_changes(user, exclude=['age']) # True
|
||||
|
||||
.. versionchanged: 0.26.6
|
||||
Added support for multiple attributes and exclude parameter.
|
||||
|
||||
:param obj: SQLAlchemy declarative model object
|
||||
:param attrs: Names of the attributes
|
||||
:param exclude: Names of the attributes to exclude
|
||||
"""
|
||||
if attrs:
|
||||
if isinstance(attrs, str):
|
||||
return (
|
||||
sa.inspect(obj)
|
||||
.attrs
|
||||
.get(attrs)
|
||||
.history
|
||||
.has_changes()
|
||||
)
|
||||
else:
|
||||
return any(has_changes(obj, attr) for attr in attrs)
|
||||
else:
|
||||
if exclude is None:
|
||||
exclude = []
|
||||
return any(
|
||||
attr.history.has_changes()
|
||||
for key, attr in sa.inspect(obj).attrs.items()
|
||||
if key not in exclude
|
||||
)
|
||||
|
||||
|
||||
def is_loaded(obj, prop):
|
||||
"""
|
||||
Return whether or not given property of given object has been loaded.
|
||||
|
||||
::
|
||||
|
||||
class Article(Base):
|
||||
__tablename__ = 'article'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.String)
|
||||
content = sa.orm.deferred(sa.Column(sa.String))
|
||||
|
||||
|
||||
article = session.query(Article).get(5)
|
||||
|
||||
# name gets loaded since its not a deferred property
|
||||
assert is_loaded(article, 'name')
|
||||
|
||||
# content has not yet been loaded since its a deferred property
|
||||
assert not is_loaded(article, 'content')
|
||||
|
||||
|
||||
.. versionadded: 0.27.8
|
||||
|
||||
:param obj: SQLAlchemy declarative model object
|
||||
:param prop: Name of the property or InstrumentedAttribute
|
||||
"""
|
||||
return prop not in sa.inspect(obj).unloaded
|
||||
|
||||
|
||||
def identity(obj_or_class):
|
||||
"""
|
||||
Return the identity of given sqlalchemy declarative model class or instance
|
||||
as a tuple. This differs from obj._sa_instance_state.identity in a way that
|
||||
it always returns the identity even if object is still in transient state (
|
||||
new object that is not yet persisted into database). Also for classes it
|
||||
returns the identity attributes.
|
||||
|
||||
::
|
||||
|
||||
from sqlalchemy import inspect
|
||||
from sqlalchemy_utils import identity
|
||||
|
||||
|
||||
user = User(name='John Matrix')
|
||||
session.add(user)
|
||||
identity(user) # None
|
||||
inspect(user).identity # None
|
||||
|
||||
session.flush() # User now has id but is still in transient state
|
||||
|
||||
identity(user) # (1,)
|
||||
inspect(user).identity # None
|
||||
|
||||
session.commit()
|
||||
|
||||
identity(user) # (1,)
|
||||
inspect(user).identity # (1, )
|
||||
|
||||
|
||||
You can also use identity for classes::
|
||||
|
||||
|
||||
identity(User) # (User.id, )
|
||||
|
||||
.. versionadded: 0.21.0
|
||||
|
||||
:param obj: SQLAlchemy declarative model object
|
||||
"""
|
||||
return tuple(
|
||||
getattr(obj_or_class, column_key)
|
||||
for column_key in get_primary_keys(obj_or_class).keys()
|
||||
)
|
||||
|
||||
|
||||
def naturally_equivalent(obj, obj2):
|
||||
"""
|
||||
Returns whether or not two given SQLAlchemy declarative instances are
|
||||
naturally equivalent (all their non primary key properties are equivalent).
|
||||
|
||||
|
||||
::
|
||||
|
||||
from sqlalchemy_utils import naturally_equivalent
|
||||
|
||||
|
||||
user = User(name='someone')
|
||||
user2 = User(name='someone')
|
||||
|
||||
user == user2 # False
|
||||
|
||||
naturally_equivalent(user, user2) # True
|
||||
|
||||
|
||||
:param obj: SQLAlchemy declarative model object
|
||||
:param obj2: SQLAlchemy declarative model object to compare with `obj`
|
||||
"""
|
||||
for column_key, column in sa.inspect(obj.__class__).columns.items():
|
||||
if column.primary_key:
|
||||
continue
|
||||
|
||||
if not (getattr(obj, column_key) == getattr(obj2, column_key)):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _get_class_registry(class_):
|
||||
try:
|
||||
return class_.registry._class_registry
|
||||
except AttributeError: # SQLAlchemy <1.4
|
||||
return class_._decl_class_registry
|
|
@ -0,0 +1,75 @@
|
|||
import inspect
|
||||
import io
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from .mock import create_mock_engine
|
||||
from .orm import _get_query_compile_state
|
||||
|
||||
|
||||
def render_expression(expression, bind, stream=None):
|
||||
"""Generate a SQL expression from the passed python expression.
|
||||
|
||||
Only the global variable, `engine`, is available for use in the
|
||||
expression. Additional local variables may be passed in the context
|
||||
parameter.
|
||||
|
||||
Note this function is meant for convenience and protected usage. Do NOT
|
||||
blindly pass user input to this function as it uses exec.
|
||||
|
||||
:param bind: A SQLAlchemy engine or bind URL.
|
||||
:param stream: Render all DDL operations to the stream.
|
||||
"""
|
||||
|
||||
# Create a stream if not present.
|
||||
|
||||
if stream is None:
|
||||
stream = io.StringIO()
|
||||
|
||||
engine = create_mock_engine(bind, stream)
|
||||
|
||||
# Navigate the stack and find the calling frame that allows the
|
||||
# expression to execuate.
|
||||
|
||||
for frame in inspect.stack()[1:]:
|
||||
try:
|
||||
frame = frame[0]
|
||||
local = dict(frame.f_locals)
|
||||
local['engine'] = engine
|
||||
exec(expression, frame.f_globals, local)
|
||||
break
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
raise ValueError('Not a valid python expression', engine)
|
||||
|
||||
return stream
|
||||
|
||||
|
||||
def render_statement(statement, bind=None):
|
||||
"""
|
||||
Generate an SQL expression string with bound parameters rendered inline
|
||||
for the given SQLAlchemy statement.
|
||||
|
||||
:param statement: SQLAlchemy Query object.
|
||||
:param bind:
|
||||
Optional SQLAlchemy bind, if None uses the bind of the given query
|
||||
object.
|
||||
"""
|
||||
|
||||
if isinstance(statement, sa.orm.query.Query):
|
||||
if bind is None:
|
||||
bind = statement.session.get_bind(
|
||||
_get_query_compile_state(statement)._mapper_zero()
|
||||
)
|
||||
|
||||
statement = statement.statement
|
||||
|
||||
elif bind is None:
|
||||
bind = statement.bind
|
||||
|
||||
stream = io.StringIO()
|
||||
engine = create_mock_engine(bind.engine, stream=stream)
|
||||
engine.execute(statement)
|
||||
|
||||
return stream.getvalue()
|
|
@ -0,0 +1,74 @@
|
|||
import sqlalchemy as sa
|
||||
|
||||
from .database import has_unique_index
|
||||
from .orm import _get_query_compile_state, get_tables
|
||||
|
||||
|
||||
def make_order_by_deterministic(query):
|
||||
"""
|
||||
Make query order by deterministic (if it isn't already). Order by is
|
||||
considered deterministic if it contains column that is unique index (
|
||||
either it is a primary key or has a unique index). Many times it is design
|
||||
flaw to order by queries in nondeterministic manner.
|
||||
|
||||
Consider a User model with three fields: id (primary key), favorite color
|
||||
and email (unique).::
|
||||
|
||||
|
||||
from sqlalchemy_utils import make_order_by_deterministic
|
||||
|
||||
|
||||
query = session.query(User).order_by(User.favorite_color)
|
||||
|
||||
query = make_order_by_deterministic(query)
|
||||
print query # 'SELECT ... ORDER BY "user".favorite_color, "user".id'
|
||||
|
||||
|
||||
query = session.query(User).order_by(User.email)
|
||||
|
||||
query = make_order_by_deterministic(query)
|
||||
print query # 'SELECT ... ORDER BY "user".email'
|
||||
|
||||
|
||||
query = session.query(User).order_by(User.id)
|
||||
|
||||
query = make_order_by_deterministic(query)
|
||||
print query # 'SELECT ... ORDER BY "user".id'
|
||||
|
||||
|
||||
.. versionadded: 0.27.1
|
||||
"""
|
||||
order_by_func = sa.asc
|
||||
|
||||
try:
|
||||
order_by_clauses = query._order_by_clauses
|
||||
except AttributeError: # SQLAlchemy <1.4
|
||||
order_by_clauses = query._order_by
|
||||
if not order_by_clauses:
|
||||
column = None
|
||||
else:
|
||||
order_by = order_by_clauses[0]
|
||||
if isinstance(order_by, sa.sql.elements._label_reference):
|
||||
order_by = order_by.element
|
||||
if isinstance(order_by, sa.sql.expression.UnaryExpression):
|
||||
if order_by.modifier == sa.sql.operators.desc_op:
|
||||
order_by_func = sa.desc
|
||||
else:
|
||||
order_by_func = sa.asc
|
||||
column = list(order_by.get_children())[0]
|
||||
else:
|
||||
column = order_by
|
||||
|
||||
# Skip queries that are ordered by an already deterministic column
|
||||
if isinstance(column, sa.Column):
|
||||
try:
|
||||
if has_unique_index(column):
|
||||
return query
|
||||
except TypeError:
|
||||
pass
|
||||
|
||||
base_table = get_tables(_get_query_compile_state(query)._entities[0])[0]
|
||||
query = query.order_by(
|
||||
*(order_by_func(c) for c in base_table.c if c.primary_key)
|
||||
)
|
||||
return query
|
Loading…
Add table
Add a link
Reference in a new issue