Cleaned up the directories

ComputerTech312 2024-02-19 15:34:25 +01:00
parent f708506d68
commit a683fcffea
1340 changed files with 554582 additions and 6840 deletions

@@ -0,0 +1,42 @@
from .database import ( # noqa
create_database,
database_exists,
drop_database,
escape_like,
has_index,
has_unique_index,
is_auto_assigned_date_column,
json_sql,
jsonb_sql
)
from .foreign_keys import ( # noqa
dependent_objects,
get_fk_constraint_for_columns,
get_referencing_foreign_keys,
group_foreign_keys,
merge_references,
non_indexed_foreign_keys
)
from .mock import create_mock_engine, mock_engine # noqa
from .orm import ( # noqa
cast_if,
get_bind,
get_class_by_table,
get_column_key,
get_columns,
get_declarative_base,
get_hybrid_properties,
get_mapper,
get_primary_keys,
get_tables,
get_type,
getdotattr,
has_changes,
identity,
is_loaded,
naturally_equivalent,
quote,
table_name
)
from .render import render_expression, render_statement # noqa
from .sort_query import make_order_by_deterministic # noqa

@@ -0,0 +1,659 @@
import itertools
import os
from collections.abc import Mapping, Sequence
from copy import copy
import sqlalchemy as sa
from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import OperationalError, ProgrammingError
from ..utils import starts_with
from .orm import quote
def escape_like(string, escape_char='*'):
"""
Escape the string parameter used in SQL LIKE expressions.
::
from sqlalchemy_utils import escape_like
query = session.query(User).filter(
User.name.ilike(escape_like('John'))
)
:param string: a string to escape
:param escape_char: escape character
"""
return (
string
.replace(escape_char, escape_char * 2)
.replace('%', escape_char + '%')
.replace('_', escape_char + '_')
)
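# Illustrative usage sketch (editor's addition, not part of the upstream
# module): shows what escape_like() returns for a pattern containing the
# LIKE wildcards '%' and '_'.
def _example_escape_like():
    # '%' and '_' are prefixed with the default escape character '*',
    # so the result is '20*% of users*_2024'.
    return escape_like('20% of users_2024')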
def json_sql(value, scalars_to_json=True):
"""
Convert python data structures to PostgreSQL specific SQLAlchemy JSON
constructs. This function is extremely useful if you need to build
PostgreSQL JSON on the Python side.
.. note::
This function needs PostgreSQL >= 9.4
Scalars are converted to to_json SQLAlchemy function objects
::
json_sql(1) # Equals SQL: to_json(1)
json_sql('a') # to_json('a')
Mappings are converted to json_build_object constructs
::
json_sql({'a': 'c', '2': 5}) # json_build_object('a', 'c', '2', 5)
Sequences (other than strings) are converted to json_build_array constructs
::
json_sql([1, 2, 3]) # json_build_array(1, 2, 3)
You can also nest these data structures
::
json_sql({'a': [1, 2, 3]})
# json_build_object('a', json_build_array(1, 2, 3))
:param value:
value to be converted to SQLAlchemy PostgreSQL function constructs
"""
scalar_convert = sa.text
if scalars_to_json:
def scalar_convert(a):
return sa.func.to_json(sa.text(a))
if isinstance(value, Mapping):
return sa.func.json_build_object(
*(
json_sql(v, scalars_to_json=False)
for v in itertools.chain(*value.items())
)
)
elif isinstance(value, str):
return scalar_convert(f"'{value}'")
elif isinstance(value, Sequence):
return sa.func.json_build_array(
*(
json_sql(v, scalars_to_json=False)
for v in value
)
)
elif isinstance(value, (int, float)):
return scalar_convert(str(value))
return value
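# Illustrative usage sketch (editor's addition): compiles the construct
# produced by json_sql() to show the generated SQL. str() uses the default
# (dialect-neutral) compiler; the functions themselves are PostgreSQL-specific.
def _example_json_sql():
    expr = json_sql({'a': [1, 2, 3]})
    # str(expr) renders roughly:
    #     json_build_object('a', json_build_array(1, 2, 3))
    return str(expr)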
def jsonb_sql(value, scalars_to_jsonb=True):
"""
Convert python data structures to PostgreSQL specific SQLAlchemy JSONB
constructs. This function is extremely useful if you need to build
PostgreSQL JSONB on the Python side.
.. note::
This function needs PostgreSQL >= 9.4
Scalars are converted to to_jsonb SQLAlchemy function objects
::
jsonb_sql(1) # Equals SQL: to_jsonb(1)
jsonb_sql('a') # to_jsonb('a')
Mappings are converted to jsonb_build_object constructs
::
jsonb_sql({'a': 'c', '2': 5}) # jsonb_build_object('a', 'c', '2', 5)
Sequences (other than strings) are converted to jsonb_build_array constructs
::
jsonb_sql([1, 2, 3]) # jsonb_build_array(1, 2, 3)
You can also nest these data structures
::
jsonb_sql({'a': [1, 2, 3]})
# jsonb_build_object('a', jsonb_build_array(1, 2, 3))
:param value:
value to be converted to SQLAlchemy PostgreSQL function constructs
:param scalars_to_jsonb:
Flag controlling whether scalar values are wrapped in a to_jsonb construct
"""
scalar_convert = sa.text
if scalars_to_jsonb:
def scalar_convert(a):
return sa.func.to_jsonb(sa.text(a))
if isinstance(value, Mapping):
return sa.func.jsonb_build_object(
*(
jsonb_sql(v, scalars_to_jsonb=False)
for v in itertools.chain(*value.items())
)
)
elif isinstance(value, str):
return scalar_convert(f"'{value}'")
elif isinstance(value, Sequence):
return sa.func.jsonb_build_array(
*(
jsonb_sql(v, scalars_to_jsonb=False)
for v in value
)
)
elif isinstance(value, (int, float)):
return scalar_convert(str(value))
return value
def has_index(column_or_constraint):
"""
Return whether or not the given column or the columns of the given foreign
key constraint have an index. A column has an index if it has a single
column index or it is the first column in a compound column index.
A foreign key constraint has an index if the constraint columns are the
first columns in a compound column index.
:param column_or_constraint:
SQLAlchemy Column object or SA ForeignKeyConstraint object
.. versionadded: 0.26.2
.. versionchanged: 0.30.18
Added support for foreign key constraints.
::
from sqlalchemy_utils import has_index
class Article(Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
title = sa.Column(sa.String(100))
is_published = sa.Column(sa.Boolean, index=True)
is_deleted = sa.Column(sa.Boolean)
is_archived = sa.Column(sa.Boolean)
__table_args__ = (
sa.Index('my_index', is_deleted, is_archived),
)
table = Article.__table__
has_index(table.c.is_published) # True
has_index(table.c.is_deleted) # True
has_index(table.c.is_archived) # False
Also supports primary key indexes
::
from sqlalchemy_utils import has_index
class ArticleTranslation(Base):
__tablename__ = 'article_translation'
id = sa.Column(sa.Integer, primary_key=True)
locale = sa.Column(sa.String(10), primary_key=True)
title = sa.Column(sa.String(100))
table = ArticleTranslation.__table__
has_index(table.c.locale) # False
has_index(table.c.id) # True
This function supports foreign key constraints as well
::
class User(Base):
__tablename__ = 'user'
first_name = sa.Column(sa.Unicode(255), primary_key=True)
last_name = sa.Column(sa.Unicode(255), primary_key=True)
class Article(Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
author_first_name = sa.Column(sa.Unicode(255))
author_last_name = sa.Column(sa.Unicode(255))
__table_args__ = (
sa.ForeignKeyConstraint(
[author_first_name, author_last_name],
[User.first_name, User.last_name]
),
sa.Index(
'my_index',
author_first_name,
author_last_name
)
)
table = Article.__table__
constraint = list(table.foreign_keys)[0].constraint
has_index(constraint) # True
"""
table = column_or_constraint.table
if not isinstance(table, sa.Table):
raise TypeError(
'Only columns belonging to Table objects are supported. Given '
'column belongs to %r.' % table
)
primary_keys = table.primary_key.columns.values()
if isinstance(column_or_constraint, sa.ForeignKeyConstraint):
columns = list(column_or_constraint.columns.values())
else:
columns = [column_or_constraint]
return (
(primary_keys and starts_with(primary_keys, columns)) or
any(
starts_with(index.columns.values(), columns)
for index in table.indexes
)
)
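# Illustrative usage sketch (editor's addition): has_index() only inspects
# in-memory metadata, so no engine or database is needed. declarative_base
# is imported from sqlalchemy.orm (SQLAlchemy 1.4+); older versions expose
# it in sqlalchemy.ext.declarative.
def _example_has_index():
    from sqlalchemy.orm import declarative_base

    Base = declarative_base()

    class Article(Base):
        __tablename__ = 'article'
        id = sa.Column(sa.Integer, primary_key=True)
        is_published = sa.Column(sa.Boolean, index=True)
        is_archived = sa.Column(sa.Boolean)

    table = Article.__table__
    # Single column index -> truthy, unindexed column -> falsy.
    return has_index(table.c.is_published), has_index(table.c.is_archived)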
def has_unique_index(column_or_constraint):
"""
Return whether or not the given column or the given foreign key constraint
has a unique index.
A column has a unique index if it has a single column primary key index or
a single column UniqueConstraint.
A foreign key constraint has a unique index if the columns of the
constraint are the same as the columns of the table primary key or the
columns of any unique index or any unique constraint of the given table.
:param column_or_constraint: SQLAlchemy Column object or SA ForeignKeyConstraint object
.. versionadded: 0.27.1
.. versionchanged: 0.30.18
Added support for foreign key constraints.
Fixed support for unique indexes (previously only worked for unique
constraints)
::
from sqlalchemy_utils import has_unique_index
class Article(Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
title = sa.Column(sa.String(100))
is_published = sa.Column(sa.Boolean, unique=True)
is_deleted = sa.Column(sa.Boolean)
is_archived = sa.Column(sa.Boolean)
table = Article.__table__
has_unique_index(table.c.is_published) # True
has_unique_index(table.c.is_deleted) # False
has_unique_index(table.c.id) # True
This function supports foreign key constraints as well
::
class User(Base):
__tablename__ = 'user'
first_name = sa.Column(sa.Unicode(255), primary_key=True)
last_name = sa.Column(sa.Unicode(255), primary_key=True)
class Article(Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
author_first_name = sa.Column(sa.Unicode(255))
author_last_name = sa.Column(sa.Unicode(255))
__table_args__ = (
sa.ForeignKeyConstraint(
[author_first_name, author_last_name],
[User.first_name, User.last_name]
),
sa.Index(
'my_index',
author_first_name,
author_last_name,
unique=True
)
)
table = Article.__table__
constraint = list(table.foreign_keys)[0].constraint
has_unique_index(constraint) # True
:raises TypeError: if given column does not belong to a Table object
"""
table = column_or_constraint.table
if not isinstance(table, sa.Table):
raise TypeError(
'Only columns belonging to Table objects are supported. Given '
'column belongs to %r.' % table
)
primary_keys = list(table.primary_key.columns.values())
if isinstance(column_or_constraint, sa.ForeignKeyConstraint):
columns = list(column_or_constraint.columns.values())
else:
columns = [column_or_constraint]
return (
(columns == primary_keys) or
any(
columns == list(constraint.columns.values())
for constraint in table.constraints
if isinstance(constraint, sa.sql.schema.UniqueConstraint)
) or
any(
columns == list(index.columns.values())
for index in table.indexes
if index.unique
)
)
def is_auto_assigned_date_column(column):
"""
Return whether or not the given SQLAlchemy Column object is an
automatically assigned DateTime or Date column.
:param column: SQLAlchemy Column object
"""
return (
(
isinstance(column.type, sa.DateTime) or
isinstance(column.type, sa.Date)
) and
(
column.default or
column.server_default or
column.onupdate or
column.server_onupdate
)
)
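# Illustrative usage sketch (editor's addition): a DateTime column counts as
# auto assigned when it carries a default / server_default / onupdate /
# server_onupdate, which is exactly what the expression above checks.
def _example_is_auto_assigned_date_column():
    import datetime

    table = sa.Table(
        'event',
        sa.MetaData(),
        sa.Column('id', sa.Integer, primary_key=True),
        sa.Column('created_at', sa.DateTime, default=datetime.datetime.utcnow),
        sa.Column('happened_at', sa.DateTime),
    )
    # created_at -> True (it has a default), happened_at -> False.
    return (
        bool(is_auto_assigned_date_column(table.c.created_at)),
        bool(is_auto_assigned_date_column(table.c.happened_at)),
    )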
def _set_url_database(url: sa.engine.url.URL, database):
"""Set the database of an engine URL.
:param url: A SQLAlchemy engine URL.
:param database: New database to set.
"""
if hasattr(url, '_replace'):
# Cannot use URL.set() as database may need to be set to None.
ret = url._replace(database=database)
else: # SQLAlchemy <1.4
url = copy(url)
url.database = database
ret = url
assert ret.database == database, ret
return ret
def _get_scalar_result(engine, sql):
with engine.connect() as conn:
return conn.scalar(sql)
def _sqlite_file_exists(database):
if not os.path.isfile(database) or os.path.getsize(database) < 100:
return False
with open(database, 'rb') as f:
header = f.read(100)
return header[:16] == b'SQLite format 3\x00'
def database_exists(url):
"""Check if a database exists.
:param url: A SQLAlchemy engine URL.
Performs backend-specific testing to quickly determine if a database
exists on the server. ::
database_exists('postgresql://postgres@localhost/name') #=> False
create_database('postgresql://postgres@localhost/name')
database_exists('postgresql://postgres@localhost/name') #=> True
Supports checking against a constructed URL as well. ::
engine = create_engine('postgresql://postgres@localhost/name')
database_exists(engine.url) #=> False
create_database(engine.url)
database_exists(engine.url) #=> True
"""
url = make_url(url)
database = url.database
dialect_name = url.get_dialect().name
engine = None
try:
if dialect_name == 'postgresql':
text = "SELECT 1 FROM pg_database WHERE datname='%s'" % database
for db in (database, 'postgres', 'template1', 'template0', None):
url = _set_url_database(url, database=db)
engine = sa.create_engine(url)
try:
return bool(_get_scalar_result(engine, sa.text(text)))
except (ProgrammingError, OperationalError):
pass
return False
elif dialect_name == 'mysql':
url = _set_url_database(url, database=None)
engine = sa.create_engine(url)
text = ("SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA "
"WHERE SCHEMA_NAME = '%s'" % database)
return bool(_get_scalar_result(engine, sa.text(text)))
elif dialect_name == 'sqlite':
url = _set_url_database(url, database=None)
engine = sa.create_engine(url)
if database:
return database == ':memory:' or _sqlite_file_exists(database)
else:
# The default SQLAlchemy database is in memory, and :memory: is
# not required, thus we should support that use case.
return True
else:
text = 'SELECT 1'
try:
engine = sa.create_engine(url)
return bool(_get_scalar_result(engine, sa.text(text)))
except (ProgrammingError, OperationalError):
return False
finally:
if engine:
engine.dispose()
def create_database(url, encoding='utf8', template=None):
"""Issue the appropriate CREATE DATABASE statement.
:param url: A SQLAlchemy engine URL.
:param encoding: The encoding to create the database as.
:param template:
The name of the template from which to create the new database. At the
moment only supported by PostgreSQL driver.
To create a database, you can pass a simple URL that would have
been passed to ``create_engine``. ::
create_database('postgresql://postgres@localhost/name')
You may also pass the url from an existing engine. ::
create_database(engine.url)
Has full support for mysql, postgres, and sqlite. In theory,
other database engines should be supported.
"""
url = make_url(url)
database = url.database
dialect_name = url.get_dialect().name
dialect_driver = url.get_dialect().driver
if dialect_name == 'postgresql':
url = _set_url_database(url, database="postgres")
elif dialect_name == 'mssql':
url = _set_url_database(url, database="master")
elif dialect_name == 'cockroachdb':
url = _set_url_database(url, database="defaultdb")
elif not dialect_name == 'sqlite':
url = _set_url_database(url, database=None)
if (dialect_name == 'mssql' and dialect_driver in {'pymssql', 'pyodbc'}) \
or (dialect_name == 'postgresql' and dialect_driver in {
'asyncpg', 'pg8000', 'psycopg', 'psycopg2', 'psycopg2cffi'}):
engine = sa.create_engine(url, isolation_level='AUTOCOMMIT')
else:
engine = sa.create_engine(url)
if dialect_name == 'postgresql':
if not template:
template = 'template1'
with engine.begin() as conn:
text = "CREATE DATABASE {} ENCODING '{}' TEMPLATE {}".format(
quote(conn, database),
encoding,
quote(conn, template)
)
conn.execute(sa.text(text))
elif dialect_name == 'mysql':
with engine.begin() as conn:
text = "CREATE DATABASE {} CHARACTER SET = '{}'".format(
quote(conn, database),
encoding
)
conn.execute(sa.text(text))
elif dialect_name == 'sqlite' and database != ':memory:':
if database:
with engine.begin() as conn:
conn.execute(sa.text('CREATE TABLE DB(id int)'))
conn.execute(sa.text('DROP TABLE DB'))
else:
with engine.begin() as conn:
text = f'CREATE DATABASE {quote(conn, database)}'
conn.execute(sa.text(text))
engine.dispose()
def drop_database(url):
"""Issue the appropriate DROP DATABASE statement.
:param url: A SQLAlchemy engine URL.
Works similar to the :ref:`create_database` method in that both url text
and a constructed url are accepted. ::
drop_database('postgresql://postgres@localhost/name')
drop_database(engine.url)
"""
url = make_url(url)
database = url.database
dialect_name = url.get_dialect().name
dialect_driver = url.get_dialect().driver
if dialect_name == 'postgresql':
url = _set_url_database(url, database="postgres")
elif dialect_name == 'mssql':
url = _set_url_database(url, database="master")
elif dialect_name == 'cockroachdb':
url = _set_url_database(url, database="defaultdb")
elif not dialect_name == 'sqlite':
url = _set_url_database(url, database=None)
if dialect_name == 'mssql' and dialect_driver in {'pymssql', 'pyodbc'}:
engine = sa.create_engine(url, connect_args={'autocommit': True})
elif dialect_name == 'postgresql' and dialect_driver in {
'asyncpg', 'pg8000', 'psycopg', 'psycopg2', 'psycopg2cffi'}:
engine = sa.create_engine(url, isolation_level='AUTOCOMMIT')
else:
engine = sa.create_engine(url)
if dialect_name == 'sqlite' and database != ':memory:':
if database:
os.remove(database)
elif dialect_name == 'postgresql':
with engine.begin() as conn:
# Disconnect all users from the database we are dropping.
version = conn.dialect.server_version_info
pid_column = (
'pid' if (version >= (9, 2)) else 'procpid'
)
text = '''
SELECT pg_terminate_backend(pg_stat_activity.{pid_column})
FROM pg_stat_activity
WHERE pg_stat_activity.datname = '{database}'
AND {pid_column} <> pg_backend_pid();
'''.format(pid_column=pid_column, database=database)
conn.execute(sa.text(text))
# Drop the database.
text = f'DROP DATABASE {quote(conn, database)}'
conn.execute(sa.text(text))
else:
with engine.begin() as conn:
text = f'DROP DATABASE {quote(conn, database)}'
conn.execute(sa.text(text))
engine.dispose()
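# Illustrative usage sketch (editor's addition): exercises the three helpers
# above against a throwaway SQLite file, which needs no database server.
# The file name used here is arbitrary.
def _example_database_lifecycle():
    url = 'sqlite:///./_scratch_example.db'
    if not database_exists(url):
        create_database(url)
    assert database_exists(url)
    drop_database(url)
    return database_exists(url)  # False again after the drop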

@@ -0,0 +1,350 @@
from collections import defaultdict
from itertools import groupby
import sqlalchemy as sa
from sqlalchemy.exc import NoInspectionAvailable
from sqlalchemy.orm import object_session
from sqlalchemy.schema import ForeignKeyConstraint, MetaData, Table
from ..query_chain import QueryChain
from .database import has_index
from .orm import _get_class_registry, get_column_key, get_mapper, get_tables
def get_foreign_key_values(fk, obj):
mapper = get_mapper(obj)
return {
fk.constraint.columns.values()[index]:
getattr(obj, element.column.key)
if hasattr(obj, element.column.key)
else getattr(
obj, mapper.get_property_by_column(element.column).key
)
for index, element in enumerate(fk.constraint.elements)
}
def group_foreign_keys(foreign_keys):
"""
Return a groupby iterator that groups given foreign keys by table.
:param foreign_keys: a sequence of foreign keys
::
foreign_keys = get_referencing_foreign_keys(User)
for table, fks in group_foreign_keys(foreign_keys):
# do something
pass
.. seealso:: :func:`get_referencing_foreign_keys`
.. versionadded: 0.26.1
"""
foreign_keys = sorted(
foreign_keys, key=lambda key: key.constraint.table.name
)
return groupby(foreign_keys, lambda key: key.constraint.table)
def get_referencing_foreign_keys(mixed):
"""
Returns referencing foreign keys for given Table object or declarative
class.
:param mixed:
SA Table object or SA declarative class
::
get_referencing_foreign_keys(User) # set([ForeignKey('user.id')])
get_referencing_foreign_keys(User.__table__)
This function also understands inheritance. This means it returns
all foreign keys that reference any table in the class inheritance tree.
Let's say you have three classes which use joined table inheritance,
namely TextItem, Article and BlogPost with Article and BlogPost inheriting
TextItem.
::
# This will check all foreign keys that reference either article table
# or textitem table.
get_referencing_foreign_keys(Article)
.. seealso:: :func:`get_tables`
"""
if isinstance(mixed, sa.Table):
tables = [mixed]
else:
tables = get_tables(mixed)
referencing_foreign_keys = set()
for table in mixed.metadata.tables.values():
if table not in tables:
for constraint in table.constraints:
if isinstance(constraint, sa.sql.schema.ForeignKeyConstraint):
for fk in constraint.elements:
if any(fk.references(t) for t in tables):
referencing_foreign_keys.add(fk)
return referencing_foreign_keys
def merge_references(from_, to, foreign_keys=None):
"""
Merge the references of an entity into another entity.
Consider the following models::
class User(self.Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String(255))
def __repr__(self):
return 'User(name=%r)' % self.name
class BlogPost(self.Base):
__tablename__ = 'blog_post'
id = sa.Column(sa.Integer, primary_key=True)
title = sa.Column(sa.String(255))
author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))
author = sa.orm.relationship(User)
Now let's add some data::
john = self.User(name='John')
jack = self.User(name='Jack')
post = self.BlogPost(title='Some title', author=john)
post2 = self.BlogPost(title='Other title', author=jack)
self.session.add_all([
john,
jack,
post,
post2
])
self.session.commit()
If we wanted to merge all John's references to Jack it would be as easy as
::
merge_references(john, jack)
self.session.commit()
post.author # User(name='Jack')
post2.author # User(name='Jack')
:param from_: an entity to merge into another entity
:param to: an entity to merge another entity into
:param foreign_keys: A sequence of foreign keys. By default this is None
indicating all referencing foreign keys should be used.
.. seealso: :func:`dependent_objects`
.. versionadded: 0.26.1
.. versionchanged: 0.40.0
Removed possibility for old-style synchronize_session merging. Only
SQL based merging supported for now.
"""
if from_.__tablename__ != to.__tablename__:
raise TypeError('The tables of given arguments do not match.')
session = object_session(from_)
if foreign_keys is None:
    foreign_keys = get_referencing_foreign_keys(from_)
for fk in foreign_keys:
old_values = get_foreign_key_values(fk, from_)
new_values = get_foreign_key_values(fk, to)
criteria = (
getattr(fk.constraint.table.c, key.key) == value
for key, value in old_values.items()
)
query = (
fk.constraint.table.update()
.where(sa.and_(*criteria))
.values(
{key.key: value for key, value in new_values.items()}
)
)
session.execute(query)
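# Illustrative usage sketch (editor's addition): end-to-end demo of
# merge_references() against an in-memory SQLite database. The models and
# data mirror the docstring above; the names are only for illustration.
def _example_merge_references():
    from sqlalchemy.orm import declarative_base, relationship, sessionmaker

    Base = declarative_base()

    class User(Base):
        __tablename__ = 'user'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.String(255))

    class BlogPost(Base):
        __tablename__ = 'blog_post'
        id = sa.Column(sa.Integer, primary_key=True)
        title = sa.Column(sa.String(255))
        author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))
        author = relationship(User)

    engine = sa.create_engine('sqlite://')
    Base.metadata.create_all(engine)
    session = sessionmaker(bind=engine)()

    john = User(name='John')
    jack = User(name='Jack')
    post = BlogPost(title='Some title', author=john)
    session.add_all([john, jack, post])
    session.commit()

    merge_references(john, jack)
    session.commit()
    return post.author.name  # 'Jack'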
def dependent_objects(obj, foreign_keys=None):
"""
Return a :class:`~sqlalchemy_utils.query_chain.QueryChain` that iterates
through all dependent objects for given SQLAlchemy object.
Consider a User object is referenced in various articles and also in
various orders. Getting all these dependent objects is as easy as::
from sqlalchemy_utils import dependent_objects
dependent_objects(user)
If you expect an object to have lots of dependent_objects it might be good
to limit the results::
dependent_objects(user).limit(5)
The common use case is checking for all restricting dependent objects before
deleting a parent object and informing the user if there are dependent objects
with ondelete='RESTRICT' foreign keys. If this kind of check is not performed,
deleting the parent will lead to nasty IntegrityErrors being raised.
In the following example we delete given user if it doesn't have any
foreign key restricted dependent objects::
from sqlalchemy_utils import get_referencing_foreign_keys
user = session.query(User).get(some_user_id)
deps = list(
dependent_objects(
user,
(
fk for fk in get_referencing_foreign_keys(User)
# On most databases RESTRICT is the default mode hence we
# check for None values also
if fk.ondelete == 'RESTRICT' or fk.ondelete is None
)
).limit(5)
)
if deps:
# Do something to inform the user
pass
else:
session.delete(user)
:param obj: SQLAlchemy declarative model object
:param foreign_keys:
A sequence of foreign keys to use for searching the dependent_objects
for given object. By default this is None, indicating that all foreign
keys referencing the object will be used.
.. note::
This function does not support exotic mappers that use multiple tables
.. seealso:: :func:`get_referencing_foreign_keys`
.. seealso:: :func:`merge_references`
.. versionadded: 0.26.0
"""
if foreign_keys is None:
foreign_keys = get_referencing_foreign_keys(obj)
session = object_session(obj)
chain = QueryChain([])
classes = _get_class_registry(obj.__class__)
for table, keys in group_foreign_keys(foreign_keys):
keys = list(keys)
for class_ in classes.values():
try:
mapper = sa.inspect(class_)
except NoInspectionAvailable:
continue
parent_mapper = mapper.inherits
if (
table in mapper.tables and
not (parent_mapper and table in parent_mapper.tables)
):
query = session.query(class_).filter(
sa.or_(*_get_criteria(keys, class_, obj))
)
chain.queries.append(query)
return chain
def _get_criteria(keys, class_, obj):
criteria = []
visited_constraints = []
for key in keys:
if key.constraint in visited_constraints:
continue
visited_constraints.append(key.constraint)
subcriteria = []
for index, column in enumerate(key.constraint.columns):
foreign_column = (
key.constraint.elements[index].column
)
subcriteria.append(
getattr(class_, get_column_key(class_, column)) ==
getattr(
obj,
sa.inspect(type(obj))
.get_property_by_column(
foreign_column
).key
)
)
criteria.append(sa.and_(*subcriteria))
return criteria
def non_indexed_foreign_keys(metadata, engine=None):
"""
Find all non-indexed foreign keys from all tables of the given MetaData.
Very useful for optimizing a PostgreSQL database and finding out which
foreign keys need indexes.
:param metadata: MetaData object to inspect tables from
"""
reflected_metadata = MetaData()
bind = getattr(metadata, 'bind', None)
if bind is None and engine is None:
raise Exception(
'Either pass a metadata object with bind or '
'pass engine as a second parameter'
)
constraints = defaultdict(list)
for table_name in metadata.tables.keys():
table = Table(
table_name,
reflected_metadata,
autoload_with=bind or engine
)
for constraint in table.constraints:
if not isinstance(constraint, ForeignKeyConstraint):
continue
if not has_index(constraint):
constraints[table.name].append(constraint)
return dict(constraints)
def get_fk_constraint_for_columns(table, *columns):
for constraint in table.constraints:
if list(constraint.columns.values()) == list(columns):
return constraint

@@ -0,0 +1,112 @@
import contextlib
import datetime
import inspect
import io
import re
import sqlalchemy as sa
def create_mock_engine(bind, stream=None):
"""Create a mock SQLAlchemy engine from the passed engine or bind URL.
:param bind: A SQLAlchemy engine or bind URL to mock.
:param stream: Render all DDL operations to the stream.
"""
if not isinstance(bind, str):
bind_url = str(bind.url)
else:
bind_url = bind
if stream is not None:
def dump(sql, *args, **kwargs):
class Compiler(type(sql._compiler(engine.dialect))):
def visit_bindparam(self, bindparam, *args, **kwargs):
return self.render_literal_value(
bindparam.value, bindparam.type)
def render_literal_value(self, value, type_):
if isinstance(value, int):
return str(value)
elif isinstance(value, (datetime.date, datetime.datetime)):
return "'%s'" % value
return super().render_literal_value(
value, type_)
text = str(Compiler(engine.dialect, sql).process(sql))
text = re.sub(r'\n+', '\n', text)
text = text.strip('\n').strip()
stream.write('\n%s;' % text)
else:
def dump(*args, **kw):
return None
try:
engine = sa.create_mock_engine(bind_url, executor=dump)
except AttributeError: # SQLAlchemy <1.4
engine = sa.create_engine(bind_url, strategy='mock', executor=dump)
return engine
@contextlib.contextmanager
def mock_engine(engine, stream=None):
"""Mocks out the engine specified in the passed bind expression.
Note this function is meant for convenience and protected usage. Do NOT
blindly pass user input to this function as it uses exec.
:param engine: A python expression that represents the engine to mock.
:param stream: Render all DDL operations to the stream.
"""
# Create a stream if not present.
if stream is None:
stream = io.StringIO()
# Navigate the stack and find the calling frame that allows the
# expression to execute.
for frame in inspect.stack()[1:]:
try:
frame = frame[0]
expression = '__target = %s' % engine
exec(expression, frame.f_globals, frame.f_locals)
target = frame.f_locals['__target']
break
except Exception:
pass
else:
raise ValueError('Not a valid python expression', engine)
# Evaluate the expression and get the target engine.
frame.f_locals['__mock'] = create_mock_engine(target, stream)
# Replace the target with our mock.
exec('%s = __mock' % engine, frame.f_globals, frame.f_locals)
# Give control back.
yield stream
# Put the target engine back.
frame.f_locals['__target'] = target
exec('%s = __target' % engine, frame.f_globals, frame.f_locals)
exec('del __target', frame.f_globals, frame.f_locals)
exec('del __mock', frame.f_globals, frame.f_locals)
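# Illustrative usage sketch (editor's addition): dumps CREATE TABLE DDL to a
# stream instead of executing it, using the mock engine built above. The
# SQLite URL is only used to pick a dialect; nothing is connected to.
def _example_create_mock_engine():
    stream = io.StringIO()
    engine = create_mock_engine('sqlite://', stream)

    metadata = sa.MetaData()
    sa.Table('item', metadata, sa.Column('id', sa.Integer, primary_key=True))
    metadata.create_all(engine)

    return stream.getvalue()  # the rendered CREATE TABLE statement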

@@ -0,0 +1,904 @@
from collections import OrderedDict
from functools import partial
from inspect import isclass
from operator import attrgetter
import sqlalchemy as sa
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import ColumnProperty, mapperlib, RelationshipProperty
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.orm.exc import UnmappedInstanceError
try:
from sqlalchemy.orm.context import _ColumnEntity, _MapperEntity
except ImportError: # SQLAlchemy <1.4
from sqlalchemy.orm.query import _ColumnEntity, _MapperEntity
from sqlalchemy.orm.session import object_session
from sqlalchemy.orm.util import AliasedInsp
from ..utils import is_sequence
def get_class_by_table(base, table, data=None):
"""
Return declarative class associated with given table. If no class is found
this function returns `None`. If multiple classes were found (polymorphic
cases) additional `data` parameter can be given to hint which class
to return.
::
class User(Base):
__tablename__ = 'entity'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String)
get_class_by_table(Base, User.__table__) # User class
This function also supports models using single table inheritance.
An additional data parameter should be provided in these cases.
::
class Entity(Base):
__tablename__ = 'entity'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String)
type = sa.Column(sa.String)
__mapper_args__ = {
'polymorphic_on': type,
'polymorphic_identity': 'entity'
}
class User(Entity):
__mapper_args__ = {
'polymorphic_identity': 'user'
}
# Entity class
get_class_by_table(Base, Entity.__table__, {'type': 'entity'})
# User class
get_class_by_table(Base, Entity.__table__, {'type': 'user'})
:param base: Declarative model base
:param table: SQLAlchemy Table object
:param data: Data row to determine the class in polymorphic scenarios
:return: Declarative class or None.
"""
found_classes = {
c for c in _get_class_registry(base).values()
if hasattr(c, '__table__') and c.__table__ is table
}
if len(found_classes) > 1:
if not data:
raise ValueError(
"Multiple declarative classes found for table '{}'. "
"Please provide data parameter for this function to be able "
"to determine polymorphic scenarios.".format(
table.name
)
)
else:
for cls in found_classes:
mapper = sa.inspect(cls)
polymorphic_on = mapper.polymorphic_on.name
if polymorphic_on in data:
if data[polymorphic_on] == mapper.polymorphic_identity:
return cls
raise ValueError(
"Multiple declarative classes found for table '{}'. Given "
"data row does not match any polymorphic identity of the "
"found classes.".format(
table.name
)
)
elif found_classes:
return found_classes.pop()
return None
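# Illustrative usage sketch (editor's addition): the lookup walks the
# declarative class registry, so plain in-memory models are enough.
def _example_get_class_by_table():
    from sqlalchemy.orm import declarative_base

    Base = declarative_base()

    class User(Base):
        __tablename__ = 'user'
        id = sa.Column(sa.Integer, primary_key=True)

    return get_class_by_table(Base, User.__table__) is User  # True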
def get_type(expr):
"""
Return the associated type with given Column, InstrumentedAttribute,
ColumnProperty, RelationshipProperty or other similar SQLAlchemy construct.
For constructs wrapping columns this is the column type. For relationships
this function returns the relationship mapper class.
:param expr:
SQLAlchemy Column, InstrumentedAttribute, ColumnProperty or other
similar SA construct.
::
class User(Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String)
class Article(Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
author_id = sa.Column(sa.Integer, sa.ForeignKey(User.id))
author = sa.orm.relationship(User)
get_type(User.__table__.c.name) # sa.String()
get_type(User.name) # sa.String()
get_type(User.name.property) # sa.String()
get_type(Article.author) # User
.. versionadded: 0.30.9
"""
if hasattr(expr, 'type'):
return expr.type
elif isinstance(expr, InstrumentedAttribute):
expr = expr.property
if isinstance(expr, ColumnProperty):
return expr.columns[0].type
elif isinstance(expr, RelationshipProperty):
return expr.mapper.class_
raise TypeError("Couldn't inspect type.")
def cast_if(expression, type_):
"""
Produce a CAST expression but only if given expression is not of given type
already.
Assume we have a model with two fields id (Integer) and name (String).
::
import sqlalchemy as sa
from sqlalchemy_utils import cast_if
cast_if(User.id, sa.Integer) # "user".id
cast_if(User.name, sa.String) # "user".name
cast_if(User.id, sa.String) # CAST("user".id AS TEXT)
This function supports scalar values as well.
::
cast_if(1, sa.Integer) # 1
cast_if('text', sa.String) # 'text'
cast_if(1, sa.String) # CAST(1 AS TEXT)
:param expression:
A SQL expression, such as a ColumnElement expression or a Python string
which will be coerced into a bound literal value.
:param type_:
A TypeEngine class or instance indicating the type to which the CAST
should apply.
.. versionadded: 0.30.14
"""
try:
expr_type = get_type(expression)
except TypeError:
expr_type = expression
check_type = type_().python_type
else:
check_type = type_
return (
sa.cast(expression, type_)
if not isinstance(expr_type, check_type)
else expression
)
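# Illustrative usage sketch (editor's addition): with plain Python scalars no
# model is needed; the exact rendered SQL may vary slightly per dialect.
def _example_cast_if():
    unchanged = cast_if(1, sa.Integer)  # returned as-is: 1
    casted = cast_if(1, sa.String)      # roughly CAST(:param_1 AS VARCHAR)
    return unchanged, str(casted)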
def get_column_key(model, column):
"""
Return the key for given column in given model.
:param model: SQLAlchemy declarative model object
::
class User(Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column('_name', sa.String)
get_column_key(User, User.__table__.c._name) # 'name'
.. versionadded: 0.26.5
.. versionchanged: 0.27.11
Throws UnmappedColumnError instead of ValueError when no property was
found for given column. This is consistent with how SQLAlchemy works.
"""
mapper = sa.inspect(model)
try:
return mapper.get_property_by_column(column).key
except sa.orm.exc.UnmappedColumnError:
for key, c in mapper.columns.items():
if c.name == column.name and c.table is column.table:
return key
raise sa.orm.exc.UnmappedColumnError(
'No column %s is configured on mapper %s...' %
(column, mapper)
)
def get_mapper(mixed):
"""
Return related SQLAlchemy Mapper for given SQLAlchemy object.
:param mixed: SQLAlchemy Table / Alias / Mapper / declarative model object
::
from sqlalchemy_utils import get_mapper
get_mapper(User)
get_mapper(User())
get_mapper(User.__table__)
get_mapper(User.__mapper__)
get_mapper(sa.orm.aliased(User))
get_mapper(sa.orm.aliased(User.__table__))
Raises:
ValueError: if multiple mappers were found for given argument
.. versionadded: 0.26.1
"""
if isinstance(mixed, _MapperEntity):
mixed = mixed.expr
elif isinstance(mixed, sa.Column):
mixed = mixed.table
elif isinstance(mixed, _ColumnEntity):
mixed = mixed.expr
if isinstance(mixed, sa.orm.Mapper):
return mixed
if isinstance(mixed, sa.orm.util.AliasedClass):
return sa.inspect(mixed).mapper
if isinstance(mixed, sa.sql.selectable.Alias):
mixed = mixed.element
if isinstance(mixed, AliasedInsp):
return mixed.mapper
if isinstance(mixed, sa.orm.attributes.InstrumentedAttribute):
mixed = mixed.class_
if isinstance(mixed, sa.Table):
if hasattr(mapperlib, '_all_registries'):
all_mappers = set()
for mapper_registry in mapperlib._all_registries():
all_mappers.update(mapper_registry.mappers)
else: # SQLAlchemy <1.4
all_mappers = mapperlib._mapper_registry
mappers = [
mapper for mapper in all_mappers
if mixed in mapper.tables
]
if len(mappers) > 1:
raise ValueError(
"Multiple mappers found for table '%s'." % mixed.name
)
elif not mappers:
raise ValueError(
"Could not get mapper for table '%s'." % mixed.name
)
else:
return mappers[0]
if not isclass(mixed):
mixed = type(mixed)
return sa.inspect(mixed)
def get_bind(obj):
"""
Return the bind for given SQLAlchemy Engine / Connection / declarative
model object.
:param obj: SQLAlchemy Engine / Connection / declarative model object
::
from sqlalchemy_utils import get_bind
get_bind(session) # Connection object
get_bind(user)
"""
if hasattr(obj, 'bind'):
conn = obj.bind
else:
try:
conn = object_session(obj).bind
except UnmappedInstanceError:
conn = obj
if not hasattr(conn, 'execute'):
raise TypeError(
'This method accepts only Session, Engine, Connection and '
'declarative model objects.'
)
return conn
def get_primary_keys(mixed):
"""
Return an OrderedDict of all primary keys for given Table object,
declarative class or declarative class instance.
:param mixed:
SA Table object, SA declarative class or SA declarative class instance
::
get_primary_keys(User)
get_primary_keys(User())
get_primary_keys(User.__table__)
get_primary_keys(User.__mapper__)
get_primary_keys(sa.orm.aliased(User))
get_primary_keys(sa.orm.aliased(User.__table__))
.. versionchanged: 0.25.3
Made the function return an ordered dictionary instead of generator.
This change was made to support primary key aliases.
Renamed this function to 'get_primary_keys', formerly 'primary_keys'
.. seealso:: :func:`get_columns`
"""
return OrderedDict(
(
(key, column) for key, column in get_columns(mixed).items()
if column.primary_key
)
)
def get_tables(mixed):
"""
Return a set of tables associated with given SQLAlchemy object.
Let's say we have three classes which use joined table inheritance
TextItem, Article and BlogPost. Article and BlogPost inherit TextItem.
::
get_tables(Article) # set([Table('article', ...), Table('text_item')])
get_tables(Article())
get_tables(Article.__mapper__)
If the TextItem entity is using with_polymorphic='*' then this function
returns all child tables (article and blog_post) as well.
::
get_tables(TextItem) # set([Table('text_item', ...)], ...])
.. versionadded: 0.26.0
:param mixed:
SQLAlchemy Mapper, Declarative class, Column, InstrumentedAttribute or
a SA Alias object wrapping any of these objects.
"""
if isinstance(mixed, sa.Table):
return [mixed]
elif isinstance(mixed, sa.Column):
return [mixed.table]
elif isinstance(mixed, sa.orm.attributes.InstrumentedAttribute):
return mixed.parent.tables
elif isinstance(mixed, _ColumnEntity):
mixed = mixed.expr
mapper = get_mapper(mixed)
polymorphic_mappers = get_polymorphic_mappers(mapper)
if polymorphic_mappers:
tables = sum((m.tables for m in polymorphic_mappers), [])
else:
tables = mapper.tables
return tables
def get_columns(mixed):
"""
Return a collection of all Column objects for given SQLAlchemy
object.
The type of the collection depends on the type of the object to return the
columns from.
::
get_columns(User)
get_columns(User())
get_columns(User.__table__)
get_columns(User.__mapper__)
get_columns(sa.orm.aliased(User))
get_columns(sa.orm.aliased(User.__table__))
:param mixed:
SA Table object, SA Mapper, SA declarative class, SA declarative class
instance or an alias of any of these objects
"""
if isinstance(mixed, sa.sql.selectable.Selectable):
try:
return mixed.selected_columns
except AttributeError: # SQLAlchemy <1.4
return mixed.c
if isinstance(mixed, sa.orm.util.AliasedClass):
return sa.inspect(mixed).mapper.columns
if isinstance(mixed, sa.orm.Mapper):
return mixed.columns
if isinstance(mixed, InstrumentedAttribute):
return mixed.property.columns
if isinstance(mixed, ColumnProperty):
return mixed.columns
if isinstance(mixed, sa.Column):
return [mixed]
if not isclass(mixed):
mixed = mixed.__class__
return sa.inspect(mixed).columns
def table_name(obj):
"""
Return the table name for the given target: a declarative class or instance,
or, for a declarative attribute, the name of the table the attribute is bound to.
"""
class_ = getattr(obj, 'class_', obj)
try:
return class_.__tablename__
except AttributeError:
pass
try:
return class_.__table__.name
except AttributeError:
pass
def getattrs(obj, attrs):
return map(partial(getattr, obj), attrs)
def quote(mixed, ident):
"""
Conditionally quote an identifier.
::
from sqlalchemy_utils import quote
engine = create_engine('sqlite:///:memory:')
quote(engine, 'order')
# '"order"'
quote(engine, 'some_other_identifier')
# 'some_other_identifier'
:param mixed: SQLAlchemy Session / Connection / Engine / Dialect object.
:param ident: identifier to conditionally quote
"""
if isinstance(mixed, Dialect):
dialect = mixed
else:
dialect = get_bind(mixed).dialect
return dialect.preparer(dialect).quote(ident)
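# Illustrative usage sketch (editor's addition): quoting only needs a dialect,
# so an in-memory SQLite engine's dialect is enough.
def _example_quote():
    engine = sa.create_engine('sqlite://')
    # 'order' is a reserved word, so it gets quoted; 'customer' does not.
    return quote(engine.dialect, 'order'), quote(engine.dialect, 'customer')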
def _get_query_compile_state(query):
if hasattr(query, '_compile_state'):
return query._compile_state()
else: # SQLAlchemy <1.4
return query
def get_polymorphic_mappers(mixed):
if isinstance(mixed, AliasedInsp):
return mixed.with_polymorphic_mappers
else:
return mixed.polymorphic_map.values()
def get_descriptor(entity, attr):
mapper = sa.inspect(entity)
for key, descriptor in get_all_descriptors(mapper).items():
if attr == key:
prop = (
descriptor.property
if hasattr(descriptor, 'property')
else None
)
if isinstance(prop, ColumnProperty):
if isinstance(entity, sa.orm.util.AliasedClass):
for c in mapper.selectable.c:
if c.key == attr:
return c
else:
# If the property belongs to a class that uses
# polymorphic inheritance we have to take into account
# situations where the attribute exists in child class
# but not in parent class.
return getattr(prop.parent.class_, attr)
else:
# Handle synonyms, relationship properties and hybrid
# properties
if isinstance(entity, sa.orm.util.AliasedClass):
return getattr(entity, attr)
try:
return getattr(mapper.class_, attr)
except AttributeError:
pass
def get_all_descriptors(expr):
if isinstance(expr, sa.sql.selectable.Selectable):
return expr.c
insp = sa.inspect(expr)
try:
polymorphic_mappers = get_polymorphic_mappers(insp)
except sa.exc.NoInspectionAvailable:
return get_mapper(expr).all_orm_descriptors
else:
attrs = dict(get_mapper(expr).all_orm_descriptors)
for submapper in polymorphic_mappers:
for key, descriptor in submapper.all_orm_descriptors.items():
if key not in attrs:
attrs[key] = descriptor
return attrs
def get_hybrid_properties(model):
"""
Returns a dictionary of hybrid property keys and hybrid properties for
given SQLAlchemy declarative model / mapper.
Consider the following model
::
from sqlalchemy.ext.hybrid import hybrid_property
class Category(Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@hybrid_property
def lowercase_name(self):
return self.name.lower()
@lowercase_name.expression
def lowercase_name(cls):
return sa.func.lower(cls.name)
You can now easily get a list of all hybrid property names
::
from sqlalchemy_utils import get_hybrid_properties
get_hybrid_properties(Category).keys() # ['lowercase_name']
This function also supports aliased classes
::
get_hybrid_properties(
sa.orm.aliased(Category)
).keys() # ['lowercase_name']
.. versionchanged: 0.26.7
This function now returns a dictionary instead of generator
.. versionchanged: 0.30.15
Added support for aliased classes
:param model: SQLAlchemy declarative model or mapper
"""
return {
key: prop
for key, prop in get_mapper(model).all_orm_descriptors.items()
if isinstance(prop, hybrid_property)
}
def get_declarative_base(model):
"""
Returns the declarative base for given model class.
:param model: SQLAlchemy declarative model
"""
for parent in model.__bases__:
try:
parent.metadata
return get_declarative_base(parent)
except AttributeError:
pass
return model
def getdotattr(obj_or_class, dot_path, condition=None):
"""
Allow dot-notated strings to be passed to `getattr`.
::
getdotattr(SubSection, 'section.document')
getdotattr(subsection, 'section.document')
:param obj_or_class: Any object or class
:param dot_path: Attribute path with dot mark as separator
"""
last = obj_or_class
for path in str(dot_path).split('.'):
getter = attrgetter(path)
if is_sequence(last):
tmp = []
for element in last:
value = getter(element)
if is_sequence(value):
tmp.extend(value)
else:
tmp.append(value)
last = tmp
elif isinstance(last, InstrumentedAttribute):
last = getter(last.property.mapper.class_)
elif last is None:
return None
else:
last = getter(last)
if condition is not None:
if is_sequence(last):
last = [v for v in last if condition(v)]
else:
if not condition(last):
return None
return last
def is_deleted(obj):
return obj in sa.orm.object_session(obj).deleted
def has_changes(obj, attrs=None, exclude=None):
"""
Simple shortcut function for checking if the given attributes of the given
declarative model object have changed during the session. Without extra
parameters this checks if the given object has any modifications. Additionally
the exclude parameter can be given to check if the given object has any changes
in any attributes other than the ones given in exclude.
::
from sqlalchemy_utils import has_changes
user = User()
has_changes(user, 'name') # False
user.name = 'someone'
has_changes(user, 'name') # True
has_changes(user) # True
You can check multiple attributes as well.
::
has_changes(user, ['age']) # True
has_changes(user, ['name', 'age']) # True
This function also supports excluding certain attributes.
::
has_changes(user, exclude=['name']) # False
has_changes(user, exclude=['age']) # True
.. versionchanged: 0.26.6
Added support for multiple attributes and exclude parameter.
:param obj: SQLAlchemy declarative model object
:param attrs: Names of the attributes
:param exclude: Names of the attributes to exclude
"""
if attrs:
if isinstance(attrs, str):
return (
sa.inspect(obj)
.attrs
.get(attrs)
.history
.has_changes()
)
else:
return any(has_changes(obj, attr) for attr in attrs)
else:
if exclude is None:
exclude = []
return any(
attr.history.has_changes()
for key, attr in sa.inspect(obj).attrs.items()
if key not in exclude
)
def is_loaded(obj, prop):
"""
Return whether or not given property of given object has been loaded.
::
class Article(Base):
__tablename__ = 'article'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String)
content = sa.orm.deferred(sa.Column(sa.String))
article = session.query(Article).get(5)
# name gets loaded since it's not a deferred property
assert is_loaded(article, 'name')
# content has not yet been loaded since it's a deferred property
assert not is_loaded(article, 'content')
.. versionadded: 0.27.8
:param obj: SQLAlchemy declarative model object
:param prop: Name of the property or InstrumentedAttribute
"""
return prop not in sa.inspect(obj).unloaded
def identity(obj_or_class):
"""
Return the identity of the given SQLAlchemy declarative model class or
instance as a tuple. This differs from obj._sa_instance_state.identity in that
it always returns the identity even if the object is still in the transient
state (a new object that has not yet been persisted into the database). For
classes it returns the identity attributes.
::
from sqlalchemy import inspect
from sqlalchemy_utils import identity
user = User(name='John Matrix')
session.add(user)
identity(user) # None
inspect(user).identity # None
session.flush() # User now has id but is still in transient state
identity(user) # (1,)
inspect(user).identity # None
session.commit()
identity(user) # (1,)
inspect(user).identity # (1, )
You can also use identity for classes::
identity(User) # (User.id, )
.. versionadded: 0.21.0
:param obj: SQLAlchemy declarative model object
"""
return tuple(
getattr(obj_or_class, column_key)
for column_key in get_primary_keys(obj_or_class).keys()
)
def naturally_equivalent(obj, obj2):
"""
Returns whether or not two given SQLAlchemy declarative instances are
naturally equivalent (all their non primary key properties are equivalent).
::
from sqlalchemy_utils import naturally_equivalent
user = User(name='someone')
user2 = User(name='someone')
user == user2 # False
naturally_equivalent(user, user2) # True
:param obj: SQLAlchemy declarative model object
:param obj2: SQLAlchemy declarative model object to compare with `obj`
"""
for column_key, column in sa.inspect(obj.__class__).columns.items():
if column.primary_key:
continue
if not (getattr(obj, column_key) == getattr(obj2, column_key)):
return False
return True
def _get_class_registry(class_):
try:
return class_.registry._class_registry
except AttributeError: # SQLAlchemy <1.4
return class_._decl_class_registry
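# Illustrative usage sketch (editor's addition): naturally_equivalent()
# compares every non primary key column attribute, so two transient objects
# carrying the same data compare as equivalent.
def _example_naturally_equivalent():
    from sqlalchemy.orm import declarative_base

    Base = declarative_base()

    class User(Base):
        __tablename__ = 'user'
        id = sa.Column(sa.Integer, primary_key=True)
        name = sa.Column(sa.String(255))

    return naturally_equivalent(User(name='someone'), User(name='someone'))  # True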

@@ -0,0 +1,75 @@
import inspect
import io
import sqlalchemy as sa
from .mock import create_mock_engine
from .orm import _get_query_compile_state
def render_expression(expression, bind, stream=None):
Generate a SQL expression from the passed Python expression.
The global variable `engine` is injected into the expression's namespace;
the local variables of the calling frame are available as well.
Note this function is meant for convenience and protected usage. Do NOT
blindly pass user input to this function as it uses exec.
:param expression: A Python expression (as a string) to execute against the mock engine.
:param bind: A SQLAlchemy engine or bind URL.
:param stream: Render all DDL operations to the stream.
"""
# Create a stream if not present.
if stream is None:
stream = io.StringIO()
engine = create_mock_engine(bind, stream)
# Navigate the stack and find the calling frame that allows the
# expression to execute.
for frame in inspect.stack()[1:]:
try:
frame = frame[0]
local = dict(frame.f_locals)
local['engine'] = engine
exec(expression, frame.f_globals, local)
break
except Exception:
pass
else:
raise ValueError('Not a valid python expression', engine)
return stream
def render_statement(statement, bind=None):
"""
Generate an SQL expression string with bound parameters rendered inline
for the given SQLAlchemy statement.
:param statement: SQLAlchemy Query object.
:param bind:
Optional SQLAlchemy bind, if None uses the bind of the given query
object.
"""
if isinstance(statement, sa.orm.query.Query):
if bind is None:
bind = statement.session.get_bind(
_get_query_compile_state(statement)._mapper_zero()
)
statement = statement.statement
elif bind is None:
bind = statement.bind
stream = io.StringIO()
engine = create_mock_engine(bind.engine, stream=stream)
engine.execute(statement)
return stream.getvalue()
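# Illustrative usage sketch (editor's addition): renders a Core select with
# its bound parameter inlined, without touching the database. Uses the
# SQLAlchemy 1.4+ select() call style.
def _example_render_statement():
    table = sa.Table(
        'user',
        sa.MetaData(),
        sa.Column('id', sa.Integer, primary_key=True),
        sa.Column('name', sa.String(255)),
    )
    statement = sa.select(table.c.name).where(table.c.id == 5)
    engine = sa.create_engine('sqlite://')
    # Renders roughly: SELECT "user".name FROM "user" WHERE "user".id = 5;
    return render_statement(statement, bind=engine)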

@@ -0,0 +1,74 @@
import sqlalchemy as sa
from .database import has_unique_index
from .orm import _get_query_compile_state, get_tables
def make_order_by_deterministic(query):
"""
Make the query order by deterministic (if it isn't already). An order by is
considered deterministic if it contains a column that is uniquely indexed
(either it is a primary key or it has a unique index). It is often a design
flaw to order queries in a nondeterministic manner.
Consider a User model with three fields: id (primary key), favorite color
and email (unique).::
from sqlalchemy_utils import make_order_by_deterministic
query = session.query(User).order_by(User.favorite_color)
query = make_order_by_deterministic(query)
print(query) # 'SELECT ... ORDER BY "user".favorite_color, "user".id'
query = session.query(User).order_by(User.email)
query = make_order_by_deterministic(query)
print(query) # 'SELECT ... ORDER BY "user".email'
query = session.query(User).order_by(User.id)
query = make_order_by_deterministic(query)
print(query) # 'SELECT ... ORDER BY "user".id'
.. versionadded: 0.27.1
"""
order_by_func = sa.asc
try:
order_by_clauses = query._order_by_clauses
except AttributeError: # SQLAlchemy <1.4
order_by_clauses = query._order_by
if not order_by_clauses:
column = None
else:
order_by = order_by_clauses[0]
if isinstance(order_by, sa.sql.elements._label_reference):
order_by = order_by.element
if isinstance(order_by, sa.sql.expression.UnaryExpression):
if order_by.modifier == sa.sql.operators.desc_op:
order_by_func = sa.desc
else:
order_by_func = sa.asc
column = list(order_by.get_children())[0]
else:
column = order_by
# Skip queries that are ordered by an already deterministic column
if isinstance(column, sa.Column):
try:
if has_unique_index(column):
return query
except TypeError:
pass
base_table = get_tables(_get_query_compile_state(query)._entities[0])[0]
query = query.order_by(
*(order_by_func(c) for c in base_table.c if c.primary_key)
)
return query
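# Illustrative usage sketch (editor's addition): shows the primary key being
# appended to a nondeterministic ORDER BY. Uses the legacy session.query()
# API that this function expects; models and names are only for illustration.
def _example_make_order_by_deterministic():
    from sqlalchemy.orm import declarative_base, sessionmaker

    Base = declarative_base()

    class User(Base):
        __tablename__ = 'user'
        id = sa.Column(sa.Integer, primary_key=True)
        favorite_color = sa.Column(sa.String(255))

    session = sessionmaker(bind=sa.create_engine('sqlite://'))()
    query = session.query(User).order_by(User.favorite_color)
    # The rendered SQL now ends roughly with:
    #     ORDER BY "user".favorite_color, "user".id
    return str(make_order_by_deterministic(query))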