Cleaned up the directories

This commit is contained in:
ComputerTech312 2024-02-19 15:34:25 +01:00
parent f708506d68
commit a683fcffea
1340 changed files with 554582 additions and 6840 deletions

View file

@ -0,0 +1,104 @@
from .aggregates import aggregated # noqa
from .asserts import ( # noqa
assert_max_length,
assert_max_value,
assert_min_value,
assert_non_nullable,
assert_nullable
)
from .exceptions import ImproperlyConfigured # noqa
from .expressions import Asterisk, row_to_json # noqa
from .functions import ( # noqa
cast_if,
create_database,
create_mock_engine,
database_exists,
dependent_objects,
drop_database,
escape_like,
get_bind,
get_class_by_table,
get_column_key,
get_columns,
get_declarative_base,
get_fk_constraint_for_columns,
get_hybrid_properties,
get_mapper,
get_primary_keys,
get_referencing_foreign_keys,
get_tables,
get_type,
group_foreign_keys,
has_changes,
has_index,
has_unique_index,
identity,
is_loaded,
json_sql,
jsonb_sql,
merge_references,
mock_engine,
naturally_equivalent,
render_expression,
render_statement,
table_name
)
from .generic import generic_relationship # noqa
from .i18n import TranslationHybrid # noqa
from .listeners import ( # noqa
auto_delete_orphans,
coercion_listener,
force_auto_coercion,
force_instant_defaults
)
from .models import generic_repr, Timestamp # noqa
from .observer import observes # noqa
from .primitives import Country, Currency, Ltree, WeekDay, WeekDays # noqa
from .proxy_dict import proxy_dict, ProxyDict # noqa
from .query_chain import QueryChain # noqa
from .types import ( # noqa
ArrowType,
Choice,
ChoiceType,
ColorType,
CompositeType,
CountryType,
CurrencyType,
DateRangeType,
DateTimeRangeType,
EmailType,
EncryptedType,
EnrichedDateTimeType,
EnrichedDateType,
instrumented_list,
InstrumentedList,
Int8RangeType,
IntRangeType,
IPAddressType,
JSONType,
LocaleType,
LtreeType,
NumericRangeType,
Password,
PasswordType,
PhoneNumber,
PhoneNumberParseException,
PhoneNumberType,
register_composites,
remove_composite_listeners,
ScalarListException,
ScalarListType,
StringEncryptedType,
TimezoneType,
TSVectorType,
URLType,
UUIDType,
WeekDaysType
)
from .view import ( # noqa
create_materialized_view,
create_view,
refresh_materialized_view
)
__version__ = '0.41.1'

View file

@ -0,0 +1,576 @@
"""
SQLAlchemy-Utils provides way of automatically calculating aggregate values of
related models and saving them to parent model.
This solution is inspired by RoR counter cache,
`counter_culture`_ and `stackoverflow reply by Michael Bayer`_.
Why?
----
Many times you may have situations where you need to calculate dynamically some
aggregate value for given model. Some simple examples include:
- Number of products in a catalog
- Average rating for movie
- Latest forum post
- Total price of orders for given customer
Now all these aggregates can be elegantly implemented with SQLAlchemy
column_property_ function. However when your data grows calculating these
values on the fly might start to hurt the performance of your application. The
more aggregates you are using the more performance penalty you get.
This module provides way of calculating these values automatically and
efficiently at the time of modification rather than on the fly.
Features
--------
* Automatically updates aggregate columns when aggregated values change
* Supports aggregate values through arbitrary number levels of relations
* Highly optimized: uses single query per transaction per aggregate column
* Aggregated columns can be of any data type and use any selectable scalar
expression
.. _column_property:
https://docs.sqlalchemy.org/en/latest/orm/mapped_sql_expr.html#using-column-property
.. _counter_culture: https://github.com/magnusvk/counter_culture
.. _stackoverflow reply by Michael Bayer:
https://stackoverflow.com/a/13765857/520932
Simple aggregates
-----------------
::
from sqlalchemy_utils import aggregated
class Thread(Base):
__tablename__ = 'thread'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated('comments', sa.Column(sa.Integer))
def comment_count(self):
return sa.func.count('1')
comments = sa.orm.relationship(
'Comment',
backref='thread'
)
class Comment(Base):
__tablename__ = 'comment'
id = sa.Column(sa.Integer, primary_key=True)
content = sa.Column(sa.UnicodeText)
thread_id = sa.Column(sa.Integer, sa.ForeignKey(Thread.id))
thread = Thread(name='SQLAlchemy development')
thread.comments.append(Comment('Going good!'))
thread.comments.append(Comment('Great new features!'))
session.add(thread)
session.commit()
thread.comment_count # 2
Custom aggregate expressions
----------------------------
Aggregate expression can be virtually any SQL expression not just a simple
function taking one parameter. You can try things such as subqueries and
different kinds of functions.
In the following example we have a Catalog of products where each catalog
knows the net worth of its products.
::
from sqlalchemy_utils import aggregated
class Catalog(Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated('products', sa.Column(sa.Integer))
def net_worth(self):
return sa.func.sum(Product.price)
products = sa.orm.relationship('Product')
class Product(Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
catalog_id = sa.Column(sa.Integer, sa.ForeignKey(Catalog.id))
Now the net_worth column of Catalog model will be automatically updated
whenever:
* A new product is added to the catalog
* A product is deleted from the catalog
* The price of catalog product is changed
::
from decimal import Decimal
product1 = Product(name='Some product', price=Decimal(1000))
product2 = Product(name='Some other product', price=Decimal(500))
catalog = Catalog(
name='My first catalog',
products=[
product1,
product2
]
)
session.add(catalog)
session.commit()
session.refresh(catalog)
catalog.net_worth # 1500
session.delete(product2)
session.commit()
session.refresh(catalog)
catalog.net_worth # 1000
product1.price = 2000
session.commit()
session.refresh(catalog)
catalog.net_worth # 2000
Multiple aggregates per class
-----------------------------
Sometimes you may need to define multiple aggregate values for same class. If
you need to define lots of relationships pointing to same class, remember to
define the relationships as viewonly when possible.
::
from sqlalchemy_utils import aggregated
class Customer(Base):
__tablename__ = 'customer'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated('orders', sa.Column(sa.Integer))
def orders_sum(self):
return sa.func.sum(Order.price)
@aggregated('invoiced_orders', sa.Column(sa.Integer))
def invoiced_orders_sum(self):
return sa.func.sum(Order.price)
orders = sa.orm.relationship('Order')
invoiced_orders = sa.orm.relationship(
'Order',
primaryjoin=
'sa.and_(Order.customer_id == Customer.id, Order.invoiced)',
viewonly=True
)
class Order(Base):
__tablename__ = 'order'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
invoiced = sa.Column(sa.Boolean, default=False)
customer_id = sa.Column(sa.Integer, sa.ForeignKey(Customer.id))
Many-to-Many aggregates
-----------------------
Aggregate expressions also support many-to-many relationships. The usual use
scenarios includes things such as:
1. Friend count of a user
2. Group count where given user belongs to
::
user_group = sa.Table('user_group', Base.metadata,
sa.Column('user_id', sa.Integer, sa.ForeignKey('user.id')),
sa.Column('group_id', sa.Integer, sa.ForeignKey('group.id'))
)
class User(Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated('groups', sa.Column(sa.Integer, default=0))
def group_count(self):
return sa.func.count('1')
groups = sa.orm.relationship(
'Group',
backref='users',
secondary=user_group
)
class Group(Base):
__tablename__ = 'group'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
user = User(name='John Matrix')
user.groups = [Group(name='Group A'), Group(name='Group B')]
session.add(user)
session.commit()
session.refresh(user)
user.group_count # 2
Multi-level aggregates
----------------------
Aggregates can span across multiple relationships. In the following example
each Catalog has a net_worth which is the sum of all products in all
categories.
::
from sqlalchemy_utils import aggregated
class Catalog(Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated('categories.products', sa.Column(sa.Integer))
def net_worth(self):
return sa.func.sum(Product.price)
categories = sa.orm.relationship('Category')
class Category(Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
catalog_id = sa.Column(sa.Integer, sa.ForeignKey(Catalog.id))
products = sa.orm.relationship('Product')
class Product(Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
price = sa.Column(sa.Numeric)
category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
Examples
--------
Average movie rating
^^^^^^^^^^^^^^^^^^^^
::
from sqlalchemy_utils import aggregated
class Movie(Base):
__tablename__ = 'movie'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.Unicode(255))
@aggregated('ratings', sa.Column(sa.Numeric))
def avg_rating(self):
return sa.func.avg(Rating.stars)
ratings = sa.orm.relationship('Rating')
class Rating(Base):
__tablename__ = 'rating'
id = sa.Column(sa.Integer, primary_key=True)
stars = sa.Column(sa.Integer)
movie_id = sa.Column(sa.Integer, sa.ForeignKey(Movie.id))
movie = Movie('Terminator 2')
movie.ratings.append(Rating(stars=5))
movie.ratings.append(Rating(stars=4))
movie.ratings.append(Rating(stars=3))
session.add(movie)
session.commit()
movie.avg_rating # 4
TODO
----
* Special consideration should be given to `deadlocks`_.
.. _deadlocks:
https://mina.naguib.ca/blog/2010/11/22/postgresql-foreign-key-deadlocks.html
"""
from collections import defaultdict
from weakref import WeakKeyDictionary
import sqlalchemy as sa
import sqlalchemy.event
import sqlalchemy.orm
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.sql.functions import _FunctionGenerator
from .compat import _select_args, get_scalar_subquery
from .functions.orm import get_column_key
from .relationships import (
chained_join,
path_to_relationships,
select_correlated_expression
)
aggregated_attrs = WeakKeyDictionary()
class AggregatedAttribute(declared_attr):
    """
    ``declared_attr`` subclass produced by the :func:`aggregated` decorator.

    When the declarative machinery reads the attribute, the descriptor
    records ``(fget, relationship, column)`` in the module-level
    ``aggregated_attrs`` registry and hands the column definition back to
    the class.
    """

    def __init__(
        self,
        fget,
        relationship,
        column,
        *args,
        **kwargs
    ):
        super().__init__(fget, *args, **kwargs)
        # Propagate the decorated function's docstring to the attribute.
        self.__doc__ = fget.__doc__
        self.column = column
        self.relationship = relationship

    # NOTE: unconventional parameter naming — ``desc`` is the descriptor
    # instance, ``self`` is the (unused) owning instance, ``cls`` the class.
    def __get__(desc, self, cls):
        value = (desc.fget, desc.relationship, desc.column)
        if cls not in aggregated_attrs:
            aggregated_attrs[cls] = [value]
        else:
            aggregated_attrs[cls].append(value)
        # The declarative system receives the plain Column definition.
        return desc.column
def local_condition(prop, objects):
    """
    Build an ``IN`` condition narrowing an aggregate UPDATE to the rows
    related to the given flushed objects.

    :param prop: relationship property joining aggregate holder and children
    :param objects: iterable of child model instances seen during the flush
    :return: ``parent_column.in_(values)`` expression, or ``None`` when no
        key value could be read (e.g. every object row was already deleted)
    """
    pairs = prop.local_remote_pairs
    if prop.secondary is not None:
        # Many-to-many: take columns from the second local/remote pair.
        # NOTE(review): parent_column and fetched_column are the very same
        # column here — appears deliberate for the secondary-table case,
        # but worth confirming against the many-to-many aggregate tests.
        parent_column = pairs[1][0]
        fetched_column = pairs[1][0]
    else:
        parent_column = pairs[0][0]
        fetched_column = pairs[0][1]
    key = get_column_key(prop.mapper, fetched_column)
    values = []
    for obj in objects:
        try:
            values.append(getattr(obj, key))
        except sa.orm.exc.ObjectDeletedError:
            # Row vanished during this flush; it cannot contribute a key.
            pass
    if values:
        return parent_column.in_(values)
def aggregate_expression(expr, class_):
    """
    Normalize an aggregate declaration into a SQL expression.

    Accepts a ready-made SQLAlchemy expression (returned untouched), a
    function generator such as ``sa.func.count`` (invoked with the literal
    ``'1'``), or a callable taking the aggregated class.
    """
    if isinstance(expr, sa.sql.visitors.Visitable):
        return expr
    if isinstance(expr, _FunctionGenerator):
        return expr(sa.sql.text('1'))
    return expr(class_)
class AggregatedValue:
    """
    Couples one aggregate column of a class with the relationship path and
    the SQL expression used to (re)calculate its value.
    """

    def __init__(self, class_, attr, path, expr):
        """
        :param class_: declarative class that holds the aggregate column
        :param attr: the aggregate column (assignment target of the UPDATE)
        :param path: relationship path string, e.g. ``'categories.products'``
        :param expr: aggregate expression; normalized by
            :func:`aggregate_expression`
        """
        self.class_ = class_
        self.attr = attr
        self.path = path
        # Relationships along the path, stored deepest-first (reversed).
        self.relationships = list(
            reversed(path_to_relationships(path, class_))
        )
        self.expr = aggregate_expression(expr, class_)

    @property
    def aggregate_query(self):
        """Scalar subquery computing the aggregate for one parent row."""
        query = select_correlated_expression(
            self.class_,
            self.expr,
            self.path,
            self.relationships[0].mapper.class_
        )
        return get_scalar_subquery(query)

    def update_query(self, objects):
        """
        Build the UPDATE statement refreshing this aggregate for the parent
        rows affected by ``objects``.

        :return: an UPDATE construct, or ``None`` when no parent row can be
            matched (e.g. every object was deleted during the flush)
        """
        table = self.class_.__table__
        query = table.update().values(
            {self.attr: self.aggregate_query}
        )
        if len(self.relationships) == 1:
            # Direct relationship: constrain the parent rows directly.
            prop = self.relationships[-1].property
            condition = local_condition(prop, objects)
            if condition is not None:
                return query.where(condition)
        else:
            # Multi-level path. Builds query such as:
            #
            # UPDATE catalog SET product_count = (aggregate_query)
            # WHERE id IN (
            #     SELECT catalog_id
            #     FROM category
            #     INNER JOIN sub_category
            #         ON category.id = sub_category.category_id
            #     WHERE sub_category.id IN (product_sub_category_ids)
            # )
            property_ = self.relationships[-1].property
            remote_pairs = property_.local_remote_pairs
            local = remote_pairs[0][0]
            remote = remote_pairs[0][1]
            condition = local_condition(
                self.relationships[0].property,
                objects
            )
            if condition is not None:
                return query.where(
                    local.in_(
                        sa.select(
                            *_select_args(remote)
                        ).select_from(
                            chained_join(*reversed(self.relationships))
                        ).where(
                            condition
                        )
                    )
                )
class AggregationManager:
    """
    Tracks every :class:`AggregatedValue` per aggregated (child) class and
    issues the aggregate UPDATE queries after each session flush.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        # Maps child class -> list of AggregatedValue objects whose
        # aggregates must be refreshed when instances of that class change.
        self.generator_registry = defaultdict(list)

    def register_listeners(self):
        """Attach the mapper-configuration and after-flush listeners."""
        sa.event.listen(
            sa.orm.Mapper,
            'after_configured',
            self.update_generator_registry
        )
        sa.event.listen(
            sa.orm.session.Session,
            'after_flush',
            self.construct_aggregate_queries
        )

    def update_generator_registry(self):
        """
        ``after_configured`` hook: turn the entries collected in
        ``aggregated_attrs`` into AggregatedValue objects, keyed by the
        class at the far end of each relationship path.
        """
        for class_, attrs in aggregated_attrs.items():
            for expr, path, column in attrs:
                value = AggregatedValue(
                    class_=class_,
                    attr=column,
                    path=path,
                    # expr is the decorated method; call it with the class
                    # to obtain the actual aggregate expression.
                    expr=expr(class_)
                )
                key = value.relationships[0].mapper.class_
                self.generator_registry[key].append(
                    value
                )

    def construct_aggregate_queries(self, session, ctx):
        """
        ``after_flush`` hook: group flushed objects by registered class and
        execute one UPDATE per affected aggregate value.
        """
        object_dict = defaultdict(list)
        for obj in session:
            for class_ in self.generator_registry:
                if isinstance(obj, class_):
                    object_dict[class_].append(obj)
        for class_, objects in object_dict.items():
            for aggregate_value in self.generator_registry[class_]:
                query = aggregate_value.update_query(objects)
                # update_query returns None when no row can be matched.
                if query is not None:
                    session.execute(query)
# Module-level singleton: registers the mapper-configuration and
# session-flush listeners as a side effect of importing this module.
manager = AggregationManager()
manager.register_listeners()
def aggregated(
    relationship,
    column
):
    """
    Decorator that generates an aggregated attribute. The decorated function
    should return an aggregate select expression.

    :param relationship:
        Defines the relationship of which the aggregate is calculated from.
        The class needs to have given relationship in order to calculate the
        aggregate.
    :param column:
        SQLAlchemy Column object. The column definition of this aggregate
        attribute.
    """
    def decorator(func):
        return AggregatedAttribute(func, relationship, column)
    return decorator

View file

@ -0,0 +1,182 @@
"""
The functions in this module can be used for testing the constraints of
your models. Each assert function runs SQL UPDATEs that check for the
existence of a given constraint. Consider the following model::
class User(Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String(200), nullable=True)
email = sa.Column(sa.String(255), nullable=False)
user = User(name='John Doe', email='john@example.com')
session.add(user)
session.commit()
We can easily test the constraints by assert_* functions::
from sqlalchemy_utils import (
assert_nullable,
assert_non_nullable,
assert_max_length
)
assert_nullable(user, 'name')
assert_non_nullable(user, 'email')
assert_max_length(user, 'name', 200)
# raises AssertionError because the max length of email is 255
assert_max_length(user, 'email', 300)
"""
from decimal import Decimal
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import ARRAY
from sqlalchemy.exc import DataError, IntegrityError
def _update_field(obj, field, value):
    """Run an UPDATE that assigns ``value`` to ``field`` on obj's table."""
    session = sa.orm.object_session(obj)
    target = sa.inspect(obj.__class__).columns[field]
    statement = target.table.update().values({target.key: value})
    session.execute(statement)
    session.flush()
def _expect_successful_update(obj, field, value, reraise_exc):
    """
    Run ``_update_field`` and fail with AssertionError if ``reraise_exc``
    is raised; the session is rolled back before failing.

    :param obj: SQLAlchemy declarative model object
    :param field: name of the column to update
    :param value: value the update is expected to accept
    :param reraise_exc: exception type(s) signalling a violated constraint
    """
    try:
        _update_field(obj, field, value)
    except reraise_exc as e:
        session = sa.orm.object_session(obj)
        session.rollback()
        # Raise explicitly instead of ``assert False, str(e)`` so the check
        # is not stripped away when Python runs with optimizations (-O).
        raise AssertionError(str(e)) from e
def _expect_failing_update(obj, field, value, expected_exc):
    """
    Run ``_update_field`` and require it to raise ``expected_exc``; the
    session is always rolled back afterwards.
    """
    raised = False
    try:
        _update_field(obj, field, value)
    except expected_exc:
        raised = True
    finally:
        sa.orm.object_session(obj).rollback()
    if not raised:
        raise AssertionError('Expected update to raise %s' % expected_exc)
def _repeated_value(type_):
    """
    Return a single repeatable unit for the given column type: a one-element
    list for ARRAY columns (typed by the array's item type), otherwise the
    one-character string ``'a'``.
    """
    if not isinstance(type_, ARRAY):
        return 'a'
    item = type_.item_type
    if isinstance(item, sa.Integer):
        return [0]
    if isinstance(item, sa.String):
        return ['a']
    if isinstance(item, sa.Numeric):
        return [Decimal('0')]
    raise TypeError('Unknown array item type')
def _expected_exception(type_):
    """Exception class a length violation raises for this column type."""
    return IntegrityError if isinstance(type_, ARRAY) else DataError
def assert_nullable(obj, column):
    """
    Assert that given column is nullable. This is checked by running an SQL
    update that assigns given column as None.

    :param obj: SQLAlchemy declarative model object
    :param column: Name of the column
    """
    # An IntegrityError here means a NOT NULL constraint fired -> failure.
    _expect_successful_update(obj, column, None, IntegrityError)
def assert_non_nullable(obj, column):
    """
    Assert that given column is not nullable. This is checked by running an SQL
    update that assigns given column as None.

    :param obj: SQLAlchemy declarative model object
    :param column: Name of the column
    """
    # The UPDATE must raise IntegrityError (NOT NULL violation) to pass.
    _expect_failing_update(obj, column, None, IntegrityError)
def assert_max_length(obj, column, max_length):
    """
    Assert that the given column is of given max length. This function supports
    string typed columns as well as PostgreSQL array typed columns.

    In the following example we add a check constraint that user can have a
    maximum of 5 favorite colors and then test this.::

        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            favorite_colors = sa.Column(ARRAY(sa.String), nullable=False)
            __table_args__ = (
                sa.CheckConstraint(
                    sa.func.array_length(favorite_colors, 1) <= 5
                )
            )

        user = User(name='John Doe', favorite_colors=['red', 'blue'])
        session.add(user)
        session.commit()

        assert_max_length(user, 'favorite_colors', 5)

    :param obj: SQLAlchemy declarative model object
    :param column: Name of the column
    :param max_length: Maximum length of given column
    """
    type_ = sa.inspect(obj.__class__).columns[column].type
    # A value of exactly max_length units must be accepted...
    _expect_successful_update(
        obj,
        column,
        _repeated_value(type_) * max_length,
        _expected_exception(type_)
    )
    # ...while max_length + 1 units must violate the constraint.
    _expect_failing_update(
        obj,
        column,
        _repeated_value(type_) * (max_length + 1),
        _expected_exception(type_)
    )
def assert_min_value(obj, column, min_value):
    """
    Assert that the given column must have a minimum value of `min_value`.

    :param obj: SQLAlchemy declarative model object
    :param column: Name of the column
    :param min_value: The minimum allowed value for given column
    """
    # min_value itself must be accepted; min_value - 1 must be rejected.
    _expect_successful_update(obj, column, min_value, IntegrityError)
    _expect_failing_update(obj, column, min_value - 1, IntegrityError)
def assert_max_value(obj, column, min_value):
    """
    Assert that the given column must have a maximum value of `min_value`.

    :param obj: SQLAlchemy declarative model object
    :param column: Name of the column
    :param min_value: The maximum allowed value for given column.
        NOTE(review): the parameter is misleadingly named ``min_value``;
        renaming it would break keyword-argument callers, so it is kept.
    """
    # min_value itself must be accepted; min_value + 1 must be rejected.
    _expect_successful_update(obj, column, min_value, IntegrityError)
    _expect_failing_update(obj, column, min_value + 1, IntegrityError)

View file

@ -0,0 +1,86 @@
import re
import sys
if sys.version_info >= (3, 8):
from importlib.metadata import metadata
else:
from importlib_metadata import metadata
def get_sqlalchemy_version(version=None):
    """
    Extract the sqlalchemy version as a tuple of integers.

    :param version: version string to parse. When ``None`` (the default)
        the installed sqlalchemy distribution's version is looked up.
    :return: tuple of leading integer components, e.g. ``(1, 4, 39)``;
        an empty tuple when the string does not start with a number.
    """
    if version is None:
        # Looked up lazily here rather than in the default argument so that
        # merely defining this function never queries package metadata.
        version = metadata("sqlalchemy")["Version"]
    match = re.search(r"^(\d+)(?:\.(\d+)(?:\.(\d+))?)?", version)
    try:
        return tuple(int(v) for v in match.groups() if v is not None)
    except AttributeError:
        # No leading digits at all -> match is None.
        return ()
# Parsed once at import time; drives the version-dependent definitions below.
_sqlalchemy_version = get_sqlalchemy_version()
# In sqlalchemy 2.0, some functions moved to sqlalchemy.orm.
# In sqlalchemy 1.3, they are only available in .ext.declarative.
# In sqlalchemy 1.4, they are available in both places.
#
# WARNING
# -------
#
# These imports are for internal, private compatibility.
# They are not supported and may change or move at any time.
# Do not import these in your own code.
#
if _sqlalchemy_version >= (1, 4):
from sqlalchemy.orm import declarative_base as _declarative_base
from sqlalchemy.orm import synonym_for as _synonym_for
else:
from sqlalchemy.ext.declarative import \
declarative_base as _declarative_base
from sqlalchemy.ext.declarative import synonym_for as _synonym_for
# scalar subqueries
if _sqlalchemy_version >= (1, 4):
    def get_scalar_subquery(query):
        """Return a scalar subquery using the 1.4+ API."""
        return query.scalar_subquery()
else:
    def get_scalar_subquery(query):
        """Return a scalar subquery using the pre-1.4 API."""
        return query.as_scalar()
# In sqlalchemy 2.0, select() columns are positional.
# In sqlalchemy 1.3, select() columns must be wrapped in a list.
#
# _select_args() is designed so its return value can be unpacked:
#
# select(*_select_args(1, 2))
#
# When sqlalchemy 1.3 support is dropped, remove the call to _select_args()
# and keep the arguments the same:
#
# select(1, 2)
#
# WARNING
# -------
#
# _select_args() is a private, internal function.
# It is not supported and may change or move at any time.
# Do not import this in your own code.
#
if _sqlalchemy_version >= (1, 4):
    def _select_args(*args):
        """sqlalchemy >= 1.4: select() takes columns positionally."""
        return args
else:
    def _select_args(*args):
        """sqlalchemy 1.3: select() requires its columns wrapped in a list."""
        return [args]
# Explicit surface of this private compat module; underscore names are
# included deliberately because sibling modules import them.
__all__ = (
    "_declarative_base",
    "get_scalar_subquery",
    "get_sqlalchemy_version",
    "_select_args",
    "_synonym_for",
)

View file

@ -0,0 +1,10 @@
"""
Global SQLAlchemy-Utils exception classes.
"""
class ImproperlyConfigured(Exception):
    """
    Raised when SQLAlchemy-Utils is improperly configured — normally because
    a utility depends on a third-party library that is not installed.
    """

View file

@ -0,0 +1,60 @@
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql.expression import ColumnElement, FunctionElement
from sqlalchemy.sql.functions import GenericFunction
from .functions.orm import quote
class array_get(FunctionElement):
    """
    Expression ``array_get(array_expr, index)``; its compiler hook renders
    zero-based subscripting as ``(array_expr)[index + 1]``.
    """
    name = 'array_get'
@compiles(array_get)
def compile_array_get(element, compiler, **kw):
    """
    Compile ``array_get(expr, index)`` into ``(expr)[index + 1]``
    (SQL array subscripts are 1-based, the Python-side index is 0-based).

    :raises TypeError: when not called with exactly two arguments, or when
        the second argument is not a bound integer literal.
    """
    args = list(element.clauses)
    if len(args) != 2:
        # TypeError rather than bare Exception: this is a misuse of the
        # function signature. Still caught by existing ``except Exception``.
        raise TypeError(
            "Function 'array_get' expects two arguments (%d given)." %
            len(args)
        )
    index = args[1]
    if not hasattr(index, 'value') or not isinstance(index.value, int):
        raise TypeError(
            "Second argument should be an integer."
        )
    # Render the shifted index directly; the original wrapped it in
    # sa.text() only to have format() str() it again, which added nothing.
    return '({})[{}]'.format(
        compiler.process(args[0]),
        index.value + 1
    )
class row_to_json(GenericFunction):
    """Generic ``row_to_json`` function with a PostgreSQL JSON return type."""
    name = 'row_to_json'
    type = postgresql.JSON
@compiles(row_to_json, 'postgresql')
def compile_row_to_json(element, compiler, **kw):
    """Render ``row_to_json(...)`` for the PostgreSQL dialect."""
    rendered_args = compiler.process(element.clauses)
    return '{}({})'.format(element.name, rendered_args)
class json_array_length(GenericFunction):
    """Generic ``json_array_length`` function returning an Integer."""
    name = 'json_array_length'
    type = sa.Integer
@compiles(json_array_length, 'postgresql')
def compile_json_array_length(element, compiler, **kw):
    """Render ``json_array_length(...)`` for the PostgreSQL dialect."""
    rendered_args = compiler.process(element.clauses)
    return '{}({})'.format(element.name, rendered_args)
class Asterisk(ColumnElement):
    """Column element rendered as ``<selectable>.*`` by its compiler hook."""
    def __init__(self, selectable):
        # selectable: object exposing a ``name`` attribute (e.g. a Table).
        self.selectable = selectable
@compiles(Asterisk)
def compile_asterisk(element, compiler, **kw):
    """Render the selectable's quoted name followed by ``.*``."""
    quoted = quote(compiler.dialect, element.selectable.name)
    return f'{quoted}.*'

View file

@ -0,0 +1,42 @@
from .database import ( # noqa
create_database,
database_exists,
drop_database,
escape_like,
has_index,
has_unique_index,
is_auto_assigned_date_column,
json_sql,
jsonb_sql
)
from .foreign_keys import ( # noqa
dependent_objects,
get_fk_constraint_for_columns,
get_referencing_foreign_keys,
group_foreign_keys,
merge_references,
non_indexed_foreign_keys
)
from .mock import create_mock_engine, mock_engine # noqa
from .orm import ( # noqa
cast_if,
get_bind,
get_class_by_table,
get_column_key,
get_columns,
get_declarative_base,
get_hybrid_properties,
get_mapper,
get_primary_keys,
get_tables,
get_type,
getdotattr,
has_changes,
identity,
is_loaded,
naturally_equivalent,
quote,
table_name
)
from .render import render_expression, render_statement # noqa
from .sort_query import make_order_by_deterministic # noqa

View file

@ -0,0 +1,659 @@
import itertools
import os
from collections.abc import Mapping, Sequence
from copy import copy
import sqlalchemy as sa
from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import OperationalError, ProgrammingError
from ..utils import starts_with
from .orm import quote
def escape_like(string, escape_char='*'):
    """
    Escape the string parameter used in SQL LIKE expressions.
    ::

        from sqlalchemy_utils import escape_like

        query = session.query(User).filter(
            User.name.ilike(escape_like('John'))
        )

    :param string: a string to escape
    :param escape_char: escape character
    """
    # Double existing escape characters first, then escape the wildcards.
    escaped = string.replace(escape_char, escape_char * 2)
    escaped = escaped.replace('%', escape_char + '%')
    return escaped.replace('_', escape_char + '_')
def json_sql(value, scalars_to_json=True):
    """
    Convert python data structures to PostgreSQL specific SQLAlchemy JSON
    constructs. This function is extremely useful if you need to build
    PostgreSQL JSON on python side.

    .. note::

        This function needs PostgreSQL >= 9.4

    Scalars are converted to to_json SQLAlchemy function objects
    ::

        json_sql(1)     # Equals SQL: to_json(1)
        json_sql('a')   # to_json('a')

    Mappings are converted to json_build_object constructs
    ::

        json_sql({'a': 'c', '2': 5})  # json_build_object('a', 'c', '2', 5)

    Sequences (other than strings) are converted to json_build_array constructs
    ::

        json_sql([1, 2, 3])  # json_build_array(1, 2, 3)

    You can also nest these data structures
    ::

        json_sql({'a': [1, 2, 3]})
        # json_build_object('a', json_build_array[1, 2, 3])

    :param value:
        value to be converted to SQLAlchemy PostgreSQL function constructs
    :param scalars_to_json:
        when False, scalars are rendered as plain SQL text instead of being
        wrapped in to_json(); used internally for nested values.
    """
    scalar_convert = sa.text
    if scalars_to_json:
        # Nested recursive calls pass scalars_to_json=False because
        # json_build_object/json_build_array already yield JSON values.
        def scalar_convert(a):
            return sa.func.to_json(sa.text(a))
    if isinstance(value, Mapping):
        return sa.func.json_build_object(
            *(
                json_sql(v, scalars_to_json=False)
                for v in itertools.chain(*value.items())
            )
        )
    elif isinstance(value, str):
        # Quote strings so they form valid SQL string literals.
        return scalar_convert(f"'{value}'")
    elif isinstance(value, Sequence):
        return sa.func.json_build_array(
            *(
                json_sql(v, scalars_to_json=False)
                for v in value
            )
        )
    elif isinstance(value, (int, float)):
        return scalar_convert(str(value))
    # Anything else (e.g. an existing SQLAlchemy construct) passes through.
    return value
def jsonb_sql(value, scalars_to_jsonb=True):
    """
    Convert python data structures to PostgreSQL specific SQLAlchemy JSONB
    constructs. This function is extremely useful if you need to build
    PostgreSQL JSONB on python side.

    .. note::

        This function needs PostgreSQL >= 9.4

    Scalars are converted to to_jsonb SQLAlchemy function objects
    ::

        jsonb_sql(1)     # Equals SQL: to_jsonb(1)
        jsonb_sql('a')   # to_jsonb('a')

    Mappings are converted to jsonb_build_object constructs
    ::

        jsonb_sql({'a': 'c', '2': 5})  # jsonb_build_object('a', 'c', '2', 5)

    Sequences (other than strings) are converted to jsonb_build_array
    constructs
    ::

        jsonb_sql([1, 2, 3])  # jsonb_build_array(1, 2, 3)

    You can also nest these data structures
    ::

        jsonb_sql({'a': [1, 2, 3]})
        # jsonb_build_object('a', jsonb_build_array[1, 2, 3])

    :param value:
        value to be converted to SQLAlchemy PostgreSQL function constructs
    :param scalars_to_jsonb:
        when False, scalars are rendered as plain SQL text instead of being
        wrapped in to_jsonb(); used internally for nested values.
    """
    scalar_convert = sa.text
    if scalars_to_jsonb:
        # Nested recursive calls pass scalars_to_jsonb=False because
        # jsonb_build_object/jsonb_build_array already yield JSONB values.
        def scalar_convert(a):
            return sa.func.to_jsonb(sa.text(a))
    if isinstance(value, Mapping):
        return sa.func.jsonb_build_object(
            *(
                jsonb_sql(v, scalars_to_jsonb=False)
                for v in itertools.chain(*value.items())
            )
        )
    elif isinstance(value, str):
        # Quote strings so they form valid SQL string literals.
        return scalar_convert(f"'{value}'")
    elif isinstance(value, Sequence):
        return sa.func.jsonb_build_array(
            *(
                jsonb_sql(v, scalars_to_jsonb=False)
                for v in value
            )
        )
    elif isinstance(value, (int, float)):
        return scalar_convert(str(value))
    # Anything else (e.g. an existing SQLAlchemy construct) passes through.
    return value
def has_index(column_or_constraint):
    """
    Return whether or not given column or the columns of given foreign key
    constraint have an index. A column has an index if it has a single column
    index or it is the first column in compound column index.

    A foreign key constraint has an index if the constraint columns are the
    first columns in compound column index.

    :param column_or_constraint:
        SQLAlchemy Column object or SA ForeignKeyConstraint object

    .. versionadded: 0.26.2

    .. versionchanged: 0.30.18
        Added support for foreign key constraints.

    ::

        from sqlalchemy_utils import has_index


        class Article(Base):
            __tablename__ = 'article'
            id = sa.Column(sa.Integer, primary_key=True)
            title = sa.Column(sa.String(100))
            is_published = sa.Column(sa.Boolean, index=True)
            is_deleted = sa.Column(sa.Boolean)
            is_archived = sa.Column(sa.Boolean)

            __table_args__ = (
                sa.Index('my_index', is_deleted, is_archived),
            )


        table = Article.__table__

        has_index(table.c.is_published) # True
        has_index(table.c.is_deleted)   # True
        has_index(table.c.is_archived)  # False


    Also supports primary key indexes
    ::

        from sqlalchemy_utils import has_index


        class ArticleTranslation(Base):
            __tablename__ = 'article_translation'
            id = sa.Column(sa.Integer, primary_key=True)
            locale = sa.Column(sa.String(10), primary_key=True)
            title = sa.Column(sa.String(100))


        table = ArticleTranslation.__table__

        has_index(table.c.locale)   # False
        has_index(table.c.id)       # True


    This function supports foreign key constraints as well
    ::


        class User(Base):
            __tablename__ = 'user'
            first_name = sa.Column(sa.Unicode(255), primary_key=True)
            last_name = sa.Column(sa.Unicode(255), primary_key=True)


        class Article(Base):
            __tablename__ = 'article'
            id = sa.Column(sa.Integer, primary_key=True)
            author_first_name = sa.Column(sa.Unicode(255))
            author_last_name = sa.Column(sa.Unicode(255))
            __table_args__ = (
                sa.ForeignKeyConstraint(
                    [author_first_name, author_last_name],
                    [User.first_name, User.last_name]
                ),
                sa.Index(
                    'my_index',
                    author_first_name,
                    author_last_name
                )
            )


        table = Article.__table__
        constraint = list(table.foreign_keys)[0].constraint

        has_index(constraint)  # True
    """
    table = column_or_constraint.table
    if not isinstance(table, sa.Table):
        raise TypeError(
            'Only columns belonging to Table objects are supported. Given '
            'column belongs to %r.' % table
        )
    primary_keys = table.primary_key.columns.values()
    # Normalize the argument to a list of columns.
    if isinstance(column_or_constraint, sa.ForeignKeyConstraint):
        columns = list(column_or_constraint.columns.values())
    else:
        columns = [column_or_constraint]
    return (
        # Indexed if the columns lead the primary key...
        (primary_keys and starts_with(primary_keys, columns)) or
        # ...or lead any explicit index on the table.
        any(
            starts_with(index.columns.values(), columns)
            for index in table.indexes
        )
    )
def has_unique_index(column_or_constraint):
    """
    Return whether or not given column or given foreign key constraint has a
    unique index.

    A column has a unique index if it has a single column primary key index or
    it has a single column UniqueConstraint.

    A foreign key constraint has a unique index if the columns of the
    constraint are the same as the columns of table primary key or the columns
    of any unique index or any unique constraint of the given table.

    :param column: SQLAlchemy Column object

    .. versionadded: 0.27.1

    .. versionchanged: 0.30.18
        Added support for foreign key constraints.

        Fixed support for unique indexes (previously only worked for unique
        constraints)

    ::

        from sqlalchemy_utils import has_unique_index


        class Article(Base):
            __tablename__ = 'article'
            id = sa.Column(sa.Integer, primary_key=True)
            title = sa.Column(sa.String(100))
            is_published = sa.Column(sa.Boolean, unique=True)
            is_deleted = sa.Column(sa.Boolean)
            is_archived = sa.Column(sa.Boolean)


        table = Article.__table__

        has_unique_index(table.c.is_published) # True
        has_unique_index(table.c.is_deleted)   # False
        has_unique_index(table.c.id)           # True


    This function supports foreign key constraints as well
    ::


        class User(Base):
            __tablename__ = 'user'
            first_name = sa.Column(sa.Unicode(255), primary_key=True)
            last_name = sa.Column(sa.Unicode(255), primary_key=True)


        class Article(Base):
            __tablename__ = 'article'
            id = sa.Column(sa.Integer, primary_key=True)
            author_first_name = sa.Column(sa.Unicode(255))
            author_last_name = sa.Column(sa.Unicode(255))
            __table_args__ = (
                sa.ForeignKeyConstraint(
                    [author_first_name, author_last_name],
                    [User.first_name, User.last_name]
                ),
                sa.Index(
                    'my_index',
                    author_first_name,
                    author_last_name,
                    unique=True
                )
            )


        table = Article.__table__
        constraint = list(table.foreign_keys)[0].constraint

        has_unique_index(constraint)  # True

    :raises TypeError: if given column does not belong to a Table object
    """
    table = column_or_constraint.table
    if not isinstance(table, sa.Table):
        raise TypeError(
            'Only columns belonging to Table objects are supported. Given '
            'column belongs to %r.' % table
        )
    primary_keys = list(table.primary_key.columns.values())
    # Normalize the argument to a list of columns.
    if isinstance(column_or_constraint, sa.ForeignKeyConstraint):
        columns = list(column_or_constraint.columns.values())
    else:
        columns = [column_or_constraint]
    return (
        # Unique if the columns exactly match the primary key...
        (columns == primary_keys) or
        # ...or any unique constraint...
        any(
            columns == list(constraint.columns.values())
            for constraint in table.constraints
            if isinstance(constraint, sa.sql.schema.UniqueConstraint)
        ) or
        # ...or any unique index.
        any(
            columns == list(index.columns.values())
            for index in table.indexes
            if index.unique
        )
    )
def is_auto_assigned_date_column(column):
    """
    Returns whether or not given SQLAlchemy Column object's is auto assigned
    DateTime or Date.

    :param column: SQLAlchemy Column object
    """
    # Truthy only when the column is a date/datetime AND carries some kind
    # of automatically assigned value (client- or server-side default, or
    # an on-update hook).
    return (
        isinstance(column.type, (sa.DateTime, sa.Date)) and
        (
            column.default or
            column.server_default or
            column.onupdate or
            column.server_onupdate
        )
    )
def _set_url_database(url: sa.engine.url.URL, database):
    """Set the database of an engine URL.

    :param url: A SQLAlchemy engine URL.
    :param database: New database to set.
    """
    if hasattr(url, '_replace'):
        # SQLAlchemy 1.4+: URL is an immutable named-tuple-like object.
        # Cannot use URL.set() as database may need to be set to None.
        new_url = url._replace(database=database)
    else:  # SQLAlchemy <1.4: URL is mutable, so work on a copy.
        new_url = copy(url)
        new_url.database = database
    assert new_url.database == database, new_url
    return new_url
def _get_scalar_result(engine, sql):
    # Execute *sql* on a short-lived connection checked out from *engine*
    # and return the first column of the first result row.
    with engine.connect() as conn:
        return conn.scalar(sql)
def _sqlite_file_exists(database):
if not os.path.isfile(database) or os.path.getsize(database) < 100:
return False
with open(database, 'rb') as f:
header = f.read(100)
return header[:16] == b'SQLite format 3\x00'
def database_exists(url):
    """Check if a database exists.
    :param url: A SQLAlchemy engine URL.
    Performs backend-specific testing to quickly determine if a database
    exists on the server. ::
        database_exists('postgresql://postgres@localhost/name') #=> False
        create_database('postgresql://postgres@localhost/name')
        database_exists('postgresql://postgres@localhost/name') #=> True
    Supports checking against a constructed URL as well. ::
        engine = create_engine('postgresql://postgres@localhost/name')
        database_exists(engine.url) #=> False
        create_database(engine.url)
        database_exists(engine.url) #=> True
    """
    url = make_url(url)
    database = url.database
    dialect_name = url.get_dialect().name
    engine = None
    try:
        if dialect_name == 'postgresql':
            text = "SELECT 1 FROM pg_database WHERE datname='%s'" % database
            # The target database itself may refuse connections, so fall
            # back through the standard maintenance databases and finally a
            # server-default connection (None).
            for db in (database, 'postgres', 'template1', 'template0', None):
                url = _set_url_database(url, database=db)
                engine = sa.create_engine(url)
                try:
                    return bool(_get_scalar_result(engine, sa.text(text)))
                except (ProgrammingError, OperationalError):
                    # Could not connect via this database; try the next one.
                    pass
            return False
        elif dialect_name == 'mysql':
            # MySQL allows connecting without selecting a database, then the
            # schema catalog can be queried directly.
            url = _set_url_database(url, database=None)
            engine = sa.create_engine(url)
            text = ("SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA "
                    "WHERE SCHEMA_NAME = '%s'" % database)
            return bool(_get_scalar_result(engine, sa.text(text)))
        elif dialect_name == 'sqlite':
            url = _set_url_database(url, database=None)
            engine = sa.create_engine(url)
            if database:
                # SQLite databases are plain files; ':memory:' always
                # "exists" by definition.
                return database == ':memory:' or _sqlite_file_exists(database)
            else:
                # The default SQLAlchemy database is in memory, and :memory: is
                # not required, thus we should support that use case.
                return True
        else:
            # Unknown dialect: probe with a trivial query and treat any
            # connection error as "does not exist".
            text = 'SELECT 1'
            try:
                engine = sa.create_engine(url)
                return bool(_get_scalar_result(engine, sa.text(text)))
            except (ProgrammingError, OperationalError):
                return False
    finally:
        # Always release pooled connections of any engine created above.
        if engine:
            engine.dispose()
def create_database(url, encoding='utf8', template=None):
    """Issue the appropriate CREATE DATABASE statement.
    :param url: A SQLAlchemy engine URL.
    :param encoding: The encoding to create the database as.
    :param template:
        The name of the template from which to create the new database. At the
        moment only supported by PostgreSQL driver.
    To create a database, you can pass a simple URL that would have
    been passed to ``create_engine``. ::
        create_database('postgresql://postgres@localhost/name')
    You may also pass the url from an existing engine. ::
        create_database(engine.url)
    Has full support for mysql, postgres, and sqlite. In theory,
    other database engines should be supported.
    """
    url = make_url(url)
    database = url.database
    dialect_name = url.get_dialect().name
    dialect_driver = url.get_dialect().driver
    # Connect through a maintenance database, since the target database
    # does not exist yet.
    if dialect_name == 'postgresql':
        url = _set_url_database(url, database="postgres")
    elif dialect_name == 'mssql':
        url = _set_url_database(url, database="master")
    elif dialect_name == 'cockroachdb':
        url = _set_url_database(url, database="defaultdb")
    elif not dialect_name == 'sqlite':
        url = _set_url_database(url, database=None)
    # CREATE DATABASE cannot run inside a transaction on these drivers, so
    # force autocommit behavior.
    if (dialect_name == 'mssql' and dialect_driver in {'pymssql', 'pyodbc'}) \
            or (dialect_name == 'postgresql' and dialect_driver in {
            'asyncpg', 'pg8000', 'psycopg', 'psycopg2', 'psycopg2cffi'}):
        engine = sa.create_engine(url, isolation_level='AUTOCOMMIT')
    else:
        engine = sa.create_engine(url)
    if dialect_name == 'postgresql':
        if not template:
            template = 'template1'
        with engine.begin() as conn:
            text = "CREATE DATABASE {} ENCODING '{}' TEMPLATE {}".format(
                quote(conn, database),
                encoding,
                quote(conn, template)
            )
            conn.execute(sa.text(text))
    elif dialect_name == 'mysql':
        with engine.begin() as conn:
            text = "CREATE DATABASE {} CHARACTER SET = '{}'".format(
                quote(conn, database),
                encoding
            )
            conn.execute(sa.text(text))
    elif dialect_name == 'sqlite' and database != ':memory:':
        if database:
            # SQLite creates the file lazily; touch it by creating and
            # dropping a throwaway table.
            with engine.begin() as conn:
                conn.execute(sa.text('CREATE TABLE DB(id int)'))
                conn.execute(sa.text('DROP TABLE DB'))
    else:
        with engine.begin() as conn:
            text = f'CREATE DATABASE {quote(conn, database)}'
            conn.execute(sa.text(text))
    engine.dispose()
def drop_database(url):
    """Issue the appropriate DROP DATABASE statement.
    :param url: A SQLAlchemy engine URL.
    Works similar to the :ref:`create_database` method in that both url text
    and a constructed url are accepted. ::
        drop_database('postgresql://postgres@localhost/name')
        drop_database(engine.url)
    """
    url = make_url(url)
    database = url.database
    dialect_name = url.get_dialect().name
    dialect_driver = url.get_dialect().driver
    # Connect through a maintenance database; we cannot drop the database
    # we are connected to.
    if dialect_name == 'postgresql':
        url = _set_url_database(url, database="postgres")
    elif dialect_name == 'mssql':
        url = _set_url_database(url, database="master")
    elif dialect_name == 'cockroachdb':
        url = _set_url_database(url, database="defaultdb")
    elif not dialect_name == 'sqlite':
        url = _set_url_database(url, database=None)
    # DROP DATABASE cannot run inside a transaction on these drivers, so
    # force autocommit behavior.
    if dialect_name == 'mssql' and dialect_driver in {'pymssql', 'pyodbc'}:
        engine = sa.create_engine(url, connect_args={'autocommit': True})
    elif dialect_name == 'postgresql' and dialect_driver in {
            'asyncpg', 'pg8000', 'psycopg', 'psycopg2', 'psycopg2cffi'}:
        engine = sa.create_engine(url, isolation_level='AUTOCOMMIT')
    else:
        engine = sa.create_engine(url)
    if dialect_name == 'sqlite' and database != ':memory:':
        if database:
            # SQLite databases are just files on disk.
            os.remove(database)
    elif dialect_name == 'postgresql':
        with engine.begin() as conn:
            # Disconnect all users from the database we are dropping.
            version = conn.dialect.server_version_info
            # PostgreSQL 9.2 renamed pg_stat_activity.procpid to pid.
            pid_column = (
                'pid' if (version >= (9, 2)) else 'procpid'
            )
            text = '''
            SELECT pg_terminate_backend(pg_stat_activity.{pid_column})
            FROM pg_stat_activity
            WHERE pg_stat_activity.datname = '{database}'
            AND {pid_column} <> pg_backend_pid();
            '''.format(pid_column=pid_column, database=database)
            conn.execute(sa.text(text))
            # Drop the database.
            text = f'DROP DATABASE {quote(conn, database)}'
            conn.execute(sa.text(text))
    else:
        with engine.begin() as conn:
            text = f'DROP DATABASE {quote(conn, database)}'
            conn.execute(sa.text(text))
    engine.dispose()

View file

@ -0,0 +1,350 @@
from collections import defaultdict
from itertools import groupby
import sqlalchemy as sa
from sqlalchemy.exc import NoInspectionAvailable
from sqlalchemy.orm import object_session
from sqlalchemy.schema import ForeignKeyConstraint, MetaData, Table
from ..query_chain import QueryChain
from .database import has_index
from .orm import _get_class_registry, get_column_key, get_mapper, get_tables
def get_foreign_key_values(fk, obj):
    """Map each local column of *fk* to the referenced value read from *obj*.

    :param fk: SQLAlchemy ForeignKey object
    :param obj: declarative model instance the referenced values are read
        from
    :return: dict mapping the constraint's local columns to values of *obj*
    """
    mapper = get_mapper(obj)
    local_columns = fk.constraint.columns.values()
    values = {}
    for index, element in enumerate(fk.constraint.elements):
        referenced_key = element.column.key
        if hasattr(obj, referenced_key):
            value = getattr(obj, referenced_key)
        else:
            # The column key differs from the mapped attribute name, so
            # resolve the attribute through the mapper instead.
            property_key = mapper.get_property_by_column(element.column).key
            value = getattr(obj, property_key)
        values[local_columns[index]] = value
    return values
def group_foreign_keys(foreign_keys):
    """
    Return a groupby iterator that groups given foreign keys by table.

    :param foreign_keys: a sequence of foreign keys

    ::

        foreign_keys = get_referencing_foreign_keys(User)

        for table, fks in group_foreign_keys(foreign_keys):
            # do something
            pass

    .. seealso:: :func:`get_referencing_foreign_keys`

    .. versionadded: 0.26.1
    """
    def owning_table(fk):
        return fk.constraint.table

    # groupby only merges adjacent items, so order the keys by the owning
    # table's name first.
    ordered = sorted(foreign_keys, key=lambda fk: owning_table(fk).name)
    return groupby(ordered, owning_table)
def get_referencing_foreign_keys(mixed):
    """
    Returns referencing foreign keys for given Table object or declarative
    class.

    :param mixed:
        SA Table object or SA declarative class

    ::

        get_referencing_foreign_keys(User)  # set([ForeignKey('user.id')])

        get_referencing_foreign_keys(User.__table__)

    This function also understands inheritance: it returns all foreign keys
    that reference any table in the class inheritance tree. E.g. for a
    joined-table-inheritance hierarchy (TextItem <- Article), calling it
    with Article checks foreign keys referencing both the article table and
    the textitem table.

    .. seealso:: :func:`get_tables`
    """
    if isinstance(mixed, sa.Table):
        tables = [mixed]
    else:
        tables = get_tables(mixed)

    referencing = set()
    for candidate in mixed.metadata.tables.values():
        if candidate in tables:
            # A table does not "reference" itself for our purposes.
            continue
        fk_constraints = (
            constraint for constraint in candidate.constraints
            if isinstance(constraint, sa.sql.schema.ForeignKeyConstraint)
        )
        for constraint in fk_constraints:
            referencing.update(
                fk for fk in constraint.elements
                if any(fk.references(table) for table in tables)
            )
    return referencing
def merge_references(from_, to, foreign_keys=None):
    """
    Merge the references of an entity into another entity.

    Consider two entities of the same mapped class, e.g. ``john`` and
    ``jack``. Merging all of John's references into Jack is as easy as::

        merge_references(john, jack)
        session.commit()

    After the commit every foreign key row that pointed at ``john`` points
    at ``jack`` instead.

    :param from_: an entity to merge into another entity
    :param to: an entity to merge another entity into
    :param foreign_keys: A sequence of foreign keys. By default this is None
        indicating all referencing foreign keys should be used.

    :raises TypeError: if the two entities are not mapped to the same table

    .. seealso: :func:`dependent_objects`

    .. versionadded: 0.26.1

    .. versionchanged: 0.40.0
        Removed possibility for old-style synchronize_session merging. Only
        SQL based merging supported for now.
    """
    if from_.__tablename__ != to.__tablename__:
        raise TypeError('The tables of given arguments do not match.')

    session = object_session(from_)
    # Bug fix: the foreign_keys parameter used to be unconditionally
    # overwritten, making it impossible to restrict the merge to a subset
    # of constraints. Only compute the default when none were given.
    if foreign_keys is None:
        foreign_keys = get_referencing_foreign_keys(from_)

    for fk in foreign_keys:
        old_values = get_foreign_key_values(fk, from_)
        new_values = get_foreign_key_values(fk, to)
        # Match rows that currently point at ``from_`` ...
        criteria = (
            getattr(fk.constraint.table.c, column.key) == value
            for column, value in old_values.items()
        )
        # ... and repoint them at ``to``.
        query = (
            fk.constraint.table.update()
            .where(sa.and_(*criteria))
            .values(
                {column.key: value for column, value in new_values.items()}
            )
        )
        session.execute(query)
def dependent_objects(obj, foreign_keys=None):
    """
    Return a :class:`~sqlalchemy_utils.query_chain.QueryChain` that iterates
    through all dependent objects for given SQLAlchemy object.
    Consider a User object is referenced in various articles and also in
    various orders. Getting all these dependent objects is as easy as::
        from sqlalchemy_utils import dependent_objects
        dependent_objects(user)
    If you expect an object to have lots of dependent_objects it might be good
    to limit the results::
        dependent_objects(user).limit(5)
    The common use case is checking for all restrict dependent objects before
    deleting parent object and inform the user if there are dependent objects
    with ondelete='RESTRICT' foreign keys. If this kind of checking is not used
    it will lead to nasty IntegrityErrors being raised.
    In the following example we delete given user if it doesn't have any
    foreign key restricted dependent objects::
        from sqlalchemy_utils import get_referencing_foreign_keys
        user = session.query(User).get(some_user_id)
        deps = list(
            dependent_objects(
                user,
                (
                    fk for fk in get_referencing_foreign_keys(User)
                    # On most databases RESTRICT is the default mode hence we
                    # check for None values also
                    if fk.ondelete == 'RESTRICT' or fk.ondelete is None
                )
            ).limit(5)
        )
        if deps:
            # Do something to inform the user
            pass
        else:
            session.delete(user)
    :param obj: SQLAlchemy declarative model object
    :param foreign_keys:
        A sequence of foreign keys to use for searching the dependent_objects
        for given object. By default this is None, indicating that all foreign
        keys referencing the object will be used.
    .. note::
        This function does not support exotic mappers that use multiple tables
    .. seealso:: :func:`get_referencing_foreign_keys`
    .. seealso:: :func:`merge_references`
    .. versionadded: 0.26.0
    """
    if foreign_keys is None:
        foreign_keys = get_referencing_foreign_keys(obj)
    session = object_session(obj)
    chain = QueryChain([])
    classes = _get_class_registry(obj.__class__)
    for table, keys in group_foreign_keys(foreign_keys):
        # ``keys`` is a one-shot groupby iterator; materialize it so it can
        # be reused for every candidate class below.
        keys = list(keys)
        for class_ in classes.values():
            try:
                mapper = sa.inspect(class_)
            except NoInspectionAvailable:
                # Registry entries that are not mapped classes are skipped.
                continue
            parent_mapper = mapper.inherits
            # Query through the class that directly owns the table; skip
            # subclasses that merely inherit the table from a parent mapper.
            if (
                table in mapper.tables and
                not (parent_mapper and table in parent_mapper.tables)
            ):
                query = session.query(class_).filter(
                    sa.or_(*_get_criteria(keys, class_, obj))
                )
                chain.queries.append(query)
    return chain
def _get_criteria(keys, class_, obj):
    """Build a list of AND-grouped filter criteria, one per distinct
    foreign key constraint in *keys*, matching rows of *class_* that
    reference *obj*.
    """
    criteria = []
    seen_constraints = []
    for key in keys:
        constraint = key.constraint
        if constraint in seen_constraints:
            # Several keys can share one composite constraint; emit a
            # single AND-group per constraint.
            continue
        seen_constraints.append(constraint)

        predicates = []
        for index, column in enumerate(constraint.columns):
            referenced_column = constraint.elements[index].column
            local_attr = getattr(class_, get_column_key(class_, column))
            remote_value = getattr(
                obj,
                sa.inspect(type(obj))
                .get_property_by_column(referenced_column).key
            )
            predicates.append(local_attr == remote_value)
        criteria.append(sa.and_(*predicates))
    return criteria
def non_indexed_foreign_keys(metadata, engine=None):
    """
    Finds all non indexed foreign keys from all tables of given MetaData.

    Very useful for optimizing postgresql database and finding out which
    foreign keys need indexes.

    :param metadata: MetaData object to inspect tables from
    :param engine: engine to reflect with when *metadata* has no bind
    """
    bind = getattr(metadata, 'bind', None)
    if bind is None and engine is None:
        raise Exception(
            'Either pass a metadata object with bind or '
            'pass engine as a second parameter'
        )
    reflected_metadata = MetaData()
    missing = defaultdict(list)
    for table_name in metadata.tables.keys():
        # Reflect each table from the live database so the real indexes
        # are visible, not just those declared on the model.
        reflected_table = Table(
            table_name,
            reflected_metadata,
            autoload_with=bind or engine
        )
        for constraint in reflected_table.constraints:
            if not isinstance(constraint, ForeignKeyConstraint):
                continue
            if not has_index(constraint):
                missing[reflected_table.name].append(constraint)
    return dict(missing)
def get_fk_constraint_for_columns(table, *columns):
    """Return the first constraint of *table* whose columns are exactly
    *columns* (same order), or None if no such constraint exists.
    """
    wanted = list(columns)
    for candidate in table.constraints:
        if list(candidate.columns.values()) == wanted:
            return candidate

View file

@ -0,0 +1,112 @@
import contextlib
import datetime
import inspect
import io
import re
import sqlalchemy as sa
def create_mock_engine(bind, stream=None):
    """Create a mock SQLAlchemy engine from the passed engine or bind URL.
    :param bind: A SQLAlchemy engine or bind URL to mock.
    :param stream: Render all DDL operations to the stream.
    """
    if not isinstance(bind, str):
        bind_url = str(bind.url)
    else:
        bind_url = bind
    if stream is not None:
        def dump(sql, *args, **kwargs):
            # Subclass the statement's own compiler so bound parameters are
            # rendered inline as literal values.
            # NOTE: ``engine`` is resolved lazily from the enclosing scope;
            # it is assigned below, before any statement can be executed.
            class Compiler(type(sql._compiler(engine.dialect))):
                def visit_bindparam(self, bindparam, *args, **kwargs):
                    return self.render_literal_value(
                        bindparam.value, bindparam.type)
                def render_literal_value(self, value, type_):
                    if isinstance(value, int):
                        return str(value)
                    elif isinstance(value, (datetime.date, datetime.datetime)):
                        return "'%s'" % value
                    return super().render_literal_value(
                        value, type_)
            text = str(Compiler(engine.dialect, sql).process(sql))
            # Collapse blank lines and surrounding whitespace for readable
            # semicolon-terminated output.
            text = re.sub(r'\n+', '\n', text)
            text = text.strip('\n').strip()
            stream.write('\n%s;' % text)
    else:
        # No stream given: swallow all statements silently.
        def dump(*args, **kw):
            return None
    try:
        engine = sa.create_mock_engine(bind_url, executor=dump)
    except AttributeError:  # SQLAlchemy <1.4
        engine = sa.create_engine(bind_url, strategy='mock', executor=dump)
    return engine
@contextlib.contextmanager
def mock_engine(engine, stream=None):
    """Mocks out the engine specified in the passed bind expression.
    Note this function is meant for convenience and protected usage. Do NOT
    blindly pass user input to this function as it uses exec.
    :param engine: A python expression that represents the engine to mock.
    :param stream: Render all DDL operations to the stream.
    """
    # Create a stream if not present.
    if stream is None:
        stream = io.StringIO()
    # Navigate the stack and find the calling frame that allows the
    # expression to execute.
    for frame in inspect.stack()[1:]:
        try:
            frame = frame[0]
            expression = '__target = %s' % engine
            exec(expression, frame.f_globals, frame.f_locals)
            target = frame.f_locals['__target']
            break
        except Exception:
            # Expression not resolvable in this frame; try the next one up.
            pass
    else:
        raise ValueError('Not a valid python expression', engine)
    # Evaluate the expression and get the target engine.
    frame.f_locals['__mock'] = create_mock_engine(target, stream)
    # Replace the target with our mock.
    exec('%s = __mock' % engine, frame.f_globals, frame.f_locals)
    # Give control back.
    yield stream
    # Put the target engine back.
    frame.f_locals['__target'] = target
    exec('%s = __target' % engine, frame.f_globals, frame.f_locals)
    exec('del __target', frame.f_globals, frame.f_locals)
    exec('del __mock', frame.f_globals, frame.f_locals)

View file

@ -0,0 +1,904 @@
from collections import OrderedDict
from functools import partial
from inspect import isclass
from operator import attrgetter
import sqlalchemy as sa
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import ColumnProperty, mapperlib, RelationshipProperty
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.orm.exc import UnmappedInstanceError
try:
from sqlalchemy.orm.context import _ColumnEntity, _MapperEntity
except ImportError: # SQLAlchemy <1.4
from sqlalchemy.orm.query import _ColumnEntity, _MapperEntity
from sqlalchemy.orm.session import object_session
from sqlalchemy.orm.util import AliasedInsp
from ..utils import is_sequence
def get_class_by_table(base, table, data=None):
    """
    Return declarative class associated with given table. If no class is found
    this function returns `None`. If multiple classes were found (polymorphic
    cases) additional `data` parameter can be given to hint which class
    to return.

    ::

        class User(Base):
            __tablename__ = 'entity'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.String)

        get_class_by_table(Base, User.__table__)  # User class

    This function also supports models using single table inheritance; the
    additional ``data`` parameter should be provided in that case, e.g.
    ``get_class_by_table(Base, Entity.__table__, {'type': 'user'})``.

    :param base: Declarative model base
    :param table: SQLAlchemy Table object
    :param data: Data row to determine the class in polymorphic scenarios
    :return: Declarative class or None.

    :raises ValueError: if several classes share *table* and either no
        ``data`` was given or no polymorphic identity matched it
    """
    matches = {
        cls for cls in _get_class_registry(base).values()
        if hasattr(cls, '__table__') and cls.__table__ is table
    }
    if not matches:
        return None
    if len(matches) == 1:
        return matches.pop()

    # Several classes share the table (single table inheritance): use the
    # polymorphic discriminator value found in *data* to pick one.
    if not data:
        raise ValueError(
            "Multiple declarative classes found for table '{}'. "
            "Please provide data parameter for this function to be able "
            "to determine polymorphic scenarios.".format(
                table.name
            )
        )
    for cls in matches:
        mapper = sa.inspect(cls)
        discriminator = mapper.polymorphic_on.name
        if discriminator in data:
            if data[discriminator] == mapper.polymorphic_identity:
                return cls
    raise ValueError(
        "Multiple declarative classes found for table '{}'. Given "
        "data row does not match any polymorphic identity of the "
        "found classes.".format(
            table.name
        )
    )
def get_type(expr):
    """
    Return the associated type with given Column, InstrumentedAttribute,
    ColumnProperty, RelationshipProperty or other similar SQLAlchemy
    construct.

    For constructs wrapping columns this is the column type. For
    relationships this function returns the relationship mapper class.

    :param expr:
        SQLAlchemy Column, InstrumentedAttribute, ColumnProperty or other
        similar SA construct.

    :raises TypeError: if the type cannot be determined

    ::

        get_type(User.__table__.c.name)  # sa.String()
        get_type(User.name)              # sa.String()
        get_type(User.name.property)     # sa.String()
        get_type(Article.author)         # User

    .. versionadded: 0.30.9
    """
    # Columns and most SQL expressions expose their type directly.
    if hasattr(expr, 'type'):
        return expr.type
    if isinstance(expr, InstrumentedAttribute):
        # Unwrap to the underlying property and fall through.
        expr = expr.property
    if isinstance(expr, ColumnProperty):
        return expr.columns[0].type
    if isinstance(expr, RelationshipProperty):
        return expr.mapper.class_
    raise TypeError("Couldn't inspect type.")
def cast_if(expression, type_):
    """
    Produce a CAST expression but only if given expression is not of given
    type already.

    ::

        import sqlalchemy as sa
        from sqlalchemy_utils import cast_if

        cast_if(User.id, sa.Integer)    # "user".id
        cast_if(User.name, sa.String)   # "user".name
        cast_if(User.id, sa.String)     # CAST("user".id AS TEXT)

    Scalar values are supported as well: ``cast_if(1, sa.String)`` produces
    ``CAST(1 AS TEXT)`` while ``cast_if(1, sa.Integer)`` is just ``1``.

    :param expression:
        A SQL expression, such as a ColumnElement expression or a Python
        string which will be coerced into a bound literal value.
    :param type_:
        A TypeEngine class or instance indicating the type to which the
        CAST should apply.

    .. versionadded: 0.30.14
    """
    try:
        current_type = get_type(expression)
    except TypeError:
        # Plain Python scalar: compare its runtime type against the SQL
        # type's corresponding python_type instead.
        current_type = expression
        target_type = type_().python_type
    else:
        target_type = type_
    if isinstance(current_type, target_type):
        return expression
    return sa.cast(expression, type_)
def get_column_key(model, column):
    """
    Return the key for given column in given model.

    :param model: SQLAlchemy declarative model object

    ::

        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column('_name', sa.String)

        get_column_key(User, User.__table__.c._name)  # 'name'

    :raises sqlalchemy.orm.exc.UnmappedColumnError:
        if no property is configured for the column

    .. versionadded: 0.26.5

    .. versionchanged: 0.27.11
        Throws UnmappedColumnError instead of ValueError when no property was
        found for given column. This is consistent with how SQLAlchemy works.
    """
    mapper = sa.inspect(model)
    try:
        return mapper.get_property_by_column(column).key
    except sa.orm.exc.UnmappedColumnError:
        # The column may be mapped under a different attribute key; fall
        # back to matching by column name and owning table.
        for key, candidate in mapper.columns.items():
            if candidate.name == column.name and candidate.table is column.table:
                return key
        raise sa.orm.exc.UnmappedColumnError(
            'No column %s is configured on mapper %s...' %
            (column, mapper)
        )
def get_mapper(mixed):
    """
    Return related SQLAlchemy Mapper for given SQLAlchemy object.
    :param mixed: SQLAlchemy Table / Alias / Mapper / declarative model object
    ::
        from sqlalchemy_utils import get_mapper
        get_mapper(User)
        get_mapper(User())
        get_mapper(User.__table__)
        get_mapper(User.__mapper__)
        get_mapper(sa.orm.aliased(User))
        get_mapper(sa.orm.aliased(User.__table__))
    Raises:
        ValueError: if multiple mappers were found for given argument
    .. versionadded: 0.26.1
    """
    # Unwrap ORM query entities / columns to their underlying objects first.
    if isinstance(mixed, _MapperEntity):
        mixed = mixed.expr
    elif isinstance(mixed, sa.Column):
        mixed = mixed.table
    elif isinstance(mixed, _ColumnEntity):
        mixed = mixed.expr
    if isinstance(mixed, sa.orm.Mapper):
        return mixed
    if isinstance(mixed, sa.orm.util.AliasedClass):
        return sa.inspect(mixed).mapper
    if isinstance(mixed, sa.sql.selectable.Alias):
        mixed = mixed.element
    if isinstance(mixed, AliasedInsp):
        return mixed.mapper
    if isinstance(mixed, sa.orm.attributes.InstrumentedAttribute):
        mixed = mixed.class_
    if isinstance(mixed, sa.Table):
        # A plain Table carries no mapper reference, so scan every known
        # mapper registry for a mapper that maps this table.
        if hasattr(mapperlib, '_all_registries'):
            all_mappers = set()
            for mapper_registry in mapperlib._all_registries():
                all_mappers.update(mapper_registry.mappers)
        else:  # SQLAlchemy <1.4
            all_mappers = mapperlib._mapper_registry
        mappers = [
            mapper for mapper in all_mappers
            if mixed in mapper.tables
        ]
        if len(mappers) > 1:
            raise ValueError(
                "Multiple mappers found for table '%s'." % mixed.name
            )
        elif not mappers:
            raise ValueError(
                "Could not get mapper for table '%s'." % mixed.name
            )
        else:
            return mappers[0]
    if not isclass(mixed):
        # An instance was given; inspect its class instead.
        mixed = type(mixed)
    return sa.inspect(mixed)
def get_bind(obj):
    """
    Return the bind for given SQLAlchemy Engine / Connection / declarative
    model object.

    :param obj: SQLAlchemy Engine / Connection / declarative model object

    :raises TypeError: if no usable bind could be derived from *obj*

    ::

        from sqlalchemy_utils import get_bind

        get_bind(session)  # Connection object
        get_bind(user)
    """
    if hasattr(obj, 'bind'):
        # Sessions (and some engines) expose their bind directly.
        bind = obj.bind
    else:
        try:
            bind = object_session(obj).bind
        except UnmappedInstanceError:
            # Not a mapped instance: assume the object itself is the bind.
            bind = obj
    if not hasattr(bind, 'execute'):
        raise TypeError(
            'This method accepts only Session, Engine, Connection and '
            'declarative model objects.'
        )
    return bind
def get_primary_keys(mixed):
    """
    Return an OrderedDict of all primary keys for given Table object,
    declarative class or declarative class instance.

    :param mixed:
        SA Table object, SA declarative class or SA declarative class
        instance

    ::

        get_primary_keys(User)
        get_primary_keys(User())
        get_primary_keys(User.__table__)
        get_primary_keys(User.__mapper__)
        get_primary_keys(sa.orm.aliased(User))
        get_primary_keys(sa.orm.aliased(User.__table__))

    .. versionchanged: 0.25.3
        Made the function return an ordered dictionary instead of generator.
        This change was made to support primary key aliases.
        Renamed this function to 'get_primary_keys', formerly 'primary_keys'

    .. seealso:: :func:`get_columns`
    """
    primary_keys = OrderedDict()
    for key, column in get_columns(mixed).items():
        if column.primary_key:
            primary_keys[key] = column
    return primary_keys
def get_tables(mixed):
    """
    Return a set of tables associated with given SQLAlchemy object.

    For joined table inheritance this includes the parent tables; if the
    entity uses ``with_polymorphic='*'`` the child tables are returned as
    well.

    ::

        get_tables(Article)  # set([Table('article', ...), ...])
        get_tables(Article())
        get_tables(Article.__mapper__)

    .. versionadded: 0.26.0

    :param mixed:
        SQLAlchemy Mapper, Declarative class, Column, InstrumentedAttribute
        or a SA Alias object wrapping any of these objects.
    """
    if isinstance(mixed, sa.Table):
        return [mixed]
    if isinstance(mixed, sa.Column):
        return [mixed.table]
    if isinstance(mixed, sa.orm.attributes.InstrumentedAttribute):
        return mixed.parent.tables
    if isinstance(mixed, _ColumnEntity):
        mixed = mixed.expr

    mapper = get_mapper(mixed)
    polymorphic = get_polymorphic_mappers(mapper)
    if polymorphic:
        # with_polymorphic queries span the tables of every child mapper.
        return sum((m.tables for m in polymorphic), [])
    return mapper.tables
def get_columns(mixed):
    """
    Return a collection of all Column objects for given SQLAlchemy
    object.

    The type of the collection depends on the type of the object to return
    the columns from.

    ::

        get_columns(User)
        get_columns(User())
        get_columns(User.__table__)
        get_columns(User.__mapper__)
        get_columns(sa.orm.aliased(User))
        get_columns(sa.orm.aliased(User.__table__))

    :param mixed:
        SA Table object, SA Mapper, SA declarative class, SA declarative
        class instance or an alias of any of these objects
    """
    # NOTE: the order of these checks matters; e.g. a Selectable must be
    # handled before the generic fallbacks below.
    if isinstance(mixed, sa.sql.selectable.Selectable):
        try:
            return mixed.selected_columns
        except AttributeError:  # SQLAlchemy <1.4
            return mixed.c
    if isinstance(mixed, sa.orm.util.AliasedClass):
        return sa.inspect(mixed).mapper.columns
    if isinstance(mixed, sa.orm.Mapper):
        return mixed.columns
    if isinstance(mixed, InstrumentedAttribute):
        return mixed.property.columns
    if isinstance(mixed, ColumnProperty):
        return mixed.columns
    if isinstance(mixed, sa.Column):
        return [mixed]
    if not isclass(mixed):
        # An instance was given; inspect its class instead.
        mixed = type(mixed)
    return sa.inspect(mixed).columns
def table_name(obj):
    """
    Return table name of given target, declarative class or the
    table name where the declarative attribute is bound to.

    Returns None if no table name could be determined.
    """
    # Attributes carry their owning class in ``class_``; otherwise the
    # object itself is inspected.
    target = getattr(obj, 'class_', obj)
    try:
        return target.__tablename__
    except AttributeError:
        pass
    try:
        return target.__table__.name
    except AttributeError:
        pass
def getattrs(obj, attrs):
    # Lazily look up each attribute name in *attrs* on *obj*.
    return map(lambda name: getattr(obj, name), attrs)
def quote(mixed, ident):
    """
    Conditionally quote an identifier.

    ::

        from sqlalchemy_utils import quote

        engine = create_engine('sqlite:///:memory:')

        quote(engine, 'order')
        # '"order"'

        quote(engine, 'some_other_identifier')
        # 'some_other_identifier'

    :param mixed: SQLAlchemy Session / Connection / Engine / Dialect object.
    :param ident: identifier to conditionally quote
    """
    # Resolve the dialect, deriving it through the bind when a session,
    # connection or engine was given instead of a dialect.
    dialect = mixed if isinstance(mixed, Dialect) else get_bind(mixed).dialect
    preparer = dialect.preparer(dialect)
    return preparer.quote(ident)
def _get_query_compile_state(query):
if hasattr(query, '_compile_state'):
return query._compile_state()
else: # SQLAlchemy <1.4
return query
def get_polymorphic_mappers(mixed):
    # Aliased entities carry their polymorphic mappers explicitly; plain
    # mappers expose them through the polymorphic map.
    if isinstance(mixed, AliasedInsp):
        return mixed.with_polymorphic_mappers
    return mixed.polymorphic_map.values()
def get_descriptor(entity, attr):
    """Return the descriptor (column, instrumented attribute, relationship,
    synonym or hybrid property) named *attr* on *entity*, honoring aliased
    classes and polymorphic mappers.

    :param entity: SQLAlchemy declarative class or AliasedClass
    :param attr: attribute name to look up
    :return: the matching descriptor, or None when *attr* is not found
    """
    mapper = sa.inspect(entity)
    for key, descriptor in get_all_descriptors(mapper).items():
        if attr == key:
            # Not every descriptor wraps a mapper property (e.g. hybrids).
            prop = (
                descriptor.property
                if hasattr(descriptor, 'property')
                else None
            )
            if isinstance(prop, ColumnProperty):
                if isinstance(entity, sa.orm.util.AliasedClass):
                    # For aliases return the aliased column itself.
                    for c in mapper.selectable.c:
                        if c.key == attr:
                            return c
                else:
                    # If the property belongs to a class that uses
                    # polymorphic inheritance we have to take into account
                    # situations where the attribute exists in child class
                    # but not in parent class.
                    return getattr(prop.parent.class_, attr)
            else:
                # Handle synonyms, relationship properties and hybrid
                # properties
                if isinstance(entity, sa.orm.util.AliasedClass):
                    return getattr(entity, attr)
                try:
                    return getattr(mapper.class_, attr)
                except AttributeError:
                    pass
def get_all_descriptors(expr):
    """
    Return all ORM descriptors for the given expression.

    Plain selectables yield their column collection. Mapped classes also
    merge in descriptors declared on polymorphic submappers, so attributes
    that only exist on child classes are included.
    """
    if isinstance(expr, sa.sql.selectable.Selectable):
        return expr.c
    insp = sa.inspect(expr)
    try:
        polymorphic_mappers = get_polymorphic_mappers(insp)
    except sa.exc.NoInspectionAvailable:
        return get_mapper(expr).all_orm_descriptors
    else:
        attrs = dict(get_mapper(expr).all_orm_descriptors)
        for submapper in polymorphic_mappers:
            for key, descriptor in submapper.all_orm_descriptors.items():
                # Descriptors of the base mapper take precedence.
                if key not in attrs:
                    attrs[key] = descriptor
        return attrs
def get_hybrid_properties(model):
    """
    Return a dictionary mapping hybrid property names to hybrid property
    descriptors for the given SQLAlchemy declarative model / mapper.
    Consider the following model
    ::
        from sqlalchemy.ext.hybrid import hybrid_property
        class Category(Base):
            __tablename__ = 'category'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))
            @hybrid_property
            def lowercase_name(self):
                return self.name.lower()
            @lowercase_name.expression
            def lowercase_name(cls):
                return sa.func.lower(cls.name)
    Getting the names of all hybrid properties is then easy
    ::
        from sqlalchemy_utils import get_hybrid_properties
        get_hybrid_properties(Category).keys()  # ['lowercase_name']
    Aliased classes are supported as well
    ::
        get_hybrid_properties(
            sa.orm.aliased(Category)
        ).keys()  # ['lowercase_name']
    .. versionchanged: 0.26.7
        This function now returns a dictionary instead of generator
    .. versionchanged: 0.30.15
        Added support for aliased classes
    :param model: SQLAlchemy declarative model or mapper
    """
    hybrids = {}
    for name, descriptor in get_mapper(model).all_orm_descriptors.items():
        if isinstance(descriptor, hybrid_property):
            hybrids[name] = descriptor
    return hybrids
def get_declarative_base(model):
    """
    Return the declarative base for the given model class by walking up the
    inheritance chain to the topmost ancestor that carries ``metadata``.
    :param model: SQLAlchemy declarative model
    """
    for base in model.__bases__:
        if hasattr(base, 'metadata'):
            return get_declarative_base(base)
    return model
def getdotattr(obj_or_class, dot_path, condition=None):
    """
    Allow dot-notated strings to be passed to `getattr`.
    ::
        getdotattr(SubSection, 'section.document')
        getdotattr(subsection, 'section.document')
    :param obj_or_class: Any object or class
    :param dot_path: Attribute path with dot mark as separator
    :param condition:
        Optional callable; for sequence results it filters the elements,
        for scalar results a falsy condition makes the function return None.
    """
    last = obj_or_class
    for path in str(dot_path).split('.'):
        getter = attrgetter(path)
        if is_sequence(last):
            # Fan out over each element of the sequence, flattening one
            # level when an element's attribute is itself a sequence.
            tmp = []
            for element in last:
                value = getter(element)
                if is_sequence(value):
                    tmp.extend(value)
                else:
                    tmp.append(value)
            last = tmp
        elif isinstance(last, InstrumentedAttribute):
            # Class-level traversal: follow the relationship to its target
            # class instead of an instance.
            last = getter(last.property.mapper.class_)
        elif last is None:
            # A broken link anywhere in the path short-circuits to None.
            return None
        else:
            last = getter(last)
    if condition is not None:
        if is_sequence(last):
            last = [v for v in last if condition(v)]
        else:
            if not condition(last):
                return None
    return last
def is_deleted(obj):
    """Return True when the object is marked for deletion in its session."""
    session = sa.orm.object_session(obj)
    return obj in session.deleted
def has_changes(obj, attrs=None, exclude=None):
    """
    Shortcut for checking whether the given attributes of the given
    declarative model object were modified during the session. Called
    without parameters it checks whether the object has any modifications
    at all. The exclude parameter restricts the whole-object check to
    attributes outside the excluded names.
    ::
        from sqlalchemy_utils import has_changes
        user = User()
        has_changes(user, 'name')  # False
        user.name = 'someone'
        has_changes(user, 'name')  # True
        has_changes(user)  # True
    Multiple attributes can be checked at once.
    ::
        has_changes(user, ['age'])  # True
        has_changes(user, ['name', 'age'])  # True
    Certain attributes may also be excluded.
    ::
        has_changes(user, exclude=['name'])  # False
        has_changes(user, exclude=['age'])  # True
    .. versionchanged: 0.26.6
        Added support for multiple attributes and exclude parameter.
    :param obj: SQLAlchemy declarative model object
    :param attrs: Names of the attributes
    :param exclude: Names of the attributes to exclude
    """
    if attrs:
        if isinstance(attrs, str):
            history = sa.inspect(obj).attrs.get(attrs).history
            return history.has_changes()
        return any(has_changes(obj, name) for name in attrs)
    excluded = exclude if exclude is not None else []
    return any(
        attr.history.has_changes()
        for key, attr in sa.inspect(obj).attrs.items()
        if key not in excluded
    )
def is_loaded(obj, prop):
    """
    Check whether the given property of the given object has already been
    loaded (i.e. is absent from its inspection state's unloaded set).
    ::
        class Article(Base):
            __tablename__ = 'article'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.String)
            content = sa.orm.deferred(sa.Column(sa.String))
        article = session.query(Article).get(5)
        # name gets loaded since its not a deferred property
        assert is_loaded(article, 'name')
        # content has not yet been loaded since its a deferred property
        assert not is_loaded(article, 'content')
    .. versionadded: 0.27.8
    :param obj: SQLAlchemy declarative model object
    :param prop: Name of the property or InstrumentedAttribute
    """
    unloaded_props = sa.inspect(obj).unloaded
    return prop not in unloaded_props
def identity(obj_or_class):
    """
    Return the identity of the given declarative model class or instance as
    a tuple. Unlike obj._sa_instance_state.identity this works even for
    transient objects (new objects not yet persisted into the database).
    For classes it returns the identity attributes themselves.
    ::
        from sqlalchemy import inspect
        from sqlalchemy_utils import identity
        user = User(name='John Matrix')
        session.add(user)
        identity(user)  # None
        inspect(user).identity  # None
        session.flush()  # User now has id but is still in transient state
        identity(user)  # (1,)
        inspect(user).identity  # None
        session.commit()
        identity(user)  # (1,)
        inspect(user).identity  # (1, )
    Classes work as well::
        identity(User)  # (User.id, )
    .. versionadded: 0.21.0
    :param obj_or_class: SQLAlchemy declarative model object or class
    """
    pk_keys = get_primary_keys(obj_or_class).keys()
    return tuple(getattr(obj_or_class, key) for key in pk_keys)
def naturally_equivalent(obj, obj2):
    """
    Check whether two SQLAlchemy declarative instances are naturally
    equivalent, meaning all of their non primary key column values are
    equal.
    ::
        from sqlalchemy_utils import naturally_equivalent
        user = User(name='someone')
        user2 = User(name='someone')
        user == user2  # False
        naturally_equivalent(user, user2)  # True
    :param obj: SQLAlchemy declarative model object
    :param obj2: SQLAlchemy declarative model object to compare with `obj`
    """
    mapped_columns = sa.inspect(obj.__class__).columns
    return all(
        getattr(obj, key) == getattr(obj2, key)
        for key, column in mapped_columns.items()
        if not column.primary_key
    )
def _get_class_registry(class_):
try:
return class_.registry._class_registry
except AttributeError: # SQLAlchemy <1.4
return class_._decl_class_registry

View file

@ -0,0 +1,75 @@
import inspect
import io
import sqlalchemy as sa
from .mock import create_mock_engine
from .orm import _get_query_compile_state
def render_expression(expression, bind, stream=None):
    """Generate a SQL expression from the passed python expression.
    Only the global variable, `engine`, is available for use in the
    expression. Additional local variables may be passed in the context
    parameter.
    Note this function is meant for convenience and protected usage. Do NOT
    blindly pass user input to this function as it uses exec.
    :param expression: Python expression string to execute against `engine`.
    :param bind: A SQLAlchemy engine or bind URL.
    :param stream: Render all DDL operations to the stream.
    :return: The stream the DDL was rendered to.
    """
    # Create a stream if not present.
    if stream is None:
        stream = io.StringIO()
    engine = create_mock_engine(bind, stream)
    # Navigate the stack and find the calling frame that allows the
    # expression to execute.
    for frame in inspect.stack()[1:]:
        try:
            frame = frame[0]
            local = dict(frame.f_locals)
            local['engine'] = engine
            exec(expression, frame.f_globals, local)
            break
        except Exception:
            # This frame's namespace didn't work; try the next outer frame.
            pass
    else:
        # No frame allowed the expression to execute successfully.
        raise ValueError('Not a valid python expression', engine)
    return stream
def render_statement(statement, bind=None):
    """
    Generate an SQL expression string with bound parameters rendered inline
    for the given SQLAlchemy statement.
    :param statement: SQLAlchemy Query object.
    :param bind:
        Optional SQLAlchemy bind, if None uses the bind of the given query
        object.
    :return: The rendered SQL string.
    """
    if isinstance(statement, sa.orm.query.Query):
        if bind is None:
            # Resolve the bind from the query's session and primary mapper.
            bind = statement.session.get_bind(
                _get_query_compile_state(statement)._mapper_zero()
            )
        statement = statement.statement
    elif bind is None:
        bind = statement.bind
    # Execute against a mock engine that writes SQL to the stream instead
    # of hitting a database.
    stream = io.StringIO()
    engine = create_mock_engine(bind.engine, stream=stream)
    engine.execute(statement)
    return stream.getvalue()

View file

@ -0,0 +1,74 @@
import sqlalchemy as sa
from .database import has_unique_index
from .orm import _get_query_compile_state, get_tables
def make_order_by_deterministic(query):
    """
    Make query order by deterministic (if it isn't already). Order by is
    considered deterministic if it contains column that is unique index (
    either it is a primary key or has a unique index). Many times it is design
    flaw to order by queries in nondeterministic manner.
    Consider a User model with three fields: id (primary key), favorite color
    and email (unique).::
        from sqlalchemy_utils import make_order_by_deterministic
        query = session.query(User).order_by(User.favorite_color)
        query = make_order_by_deterministic(query)
        print query  # 'SELECT ... ORDER BY "user".favorite_color, "user".id'
        query = session.query(User).order_by(User.email)
        query = make_order_by_deterministic(query)
        print query  # 'SELECT ... ORDER BY "user".email'
        query = session.query(User).order_by(User.id)
        query = make_order_by_deterministic(query)
        print query  # 'SELECT ... ORDER BY "user".id'
    .. versionadded: 0.27.1
    :param query: SQLAlchemy Query object to make deterministic
    :return: the (possibly extended) query
    """
    order_by_func = sa.asc
    try:
        order_by_clauses = query._order_by_clauses
    except AttributeError:  # SQLAlchemy <1.4
        order_by_clauses = query._order_by
    if not order_by_clauses:
        column = None
    else:
        # Inspect only the first order by clause.
        order_by = order_by_clauses[0]
        if isinstance(order_by, sa.sql.elements._label_reference):
            order_by = order_by.element
        if isinstance(order_by, sa.sql.expression.UnaryExpression):
            # Preserve the direction (ASC / DESC) of the existing order by.
            if order_by.modifier == sa.sql.operators.desc_op:
                order_by_func = sa.desc
            else:
                order_by_func = sa.asc
            column = list(order_by.get_children())[0]
        else:
            column = order_by
    # Skip queries that are ordered by an already deterministic column
    if isinstance(column, sa.Column):
        try:
            if has_unique_index(column):
                return query
        except TypeError:
            # NOTE(review): presumably raised when the column is not bound
            # to a table — confirm against has_unique_index.
            pass
    # Append the base table's primary key columns as a tie breaker.
    base_table = get_tables(_get_query_compile_state(query)._entities[0])[0]
    query = query.order_by(
        *(order_by_func(c) for c in base_table.c if c.primary_key)
    )
    return query

View file

@ -0,0 +1,185 @@
from collections.abc import Iterable
import sqlalchemy as sa
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.orm import attributes, class_mapper, ColumnProperty
from sqlalchemy.orm.interfaces import MapperProperty, PropComparator
from sqlalchemy.orm.session import _state_session
from sqlalchemy.util import set_creation_order
from .exceptions import ImproperlyConfigured
from .functions import identity
from .functions.orm import _get_class_registry
class GenericAttributeImpl(attributes.ScalarAttributeImpl):
    """Attribute implementation backing a generic relationship.

    Lazily loads the target object using the discriminator (target class
    name) and id columns stored on the parent row.
    """

    def get(self, state, dict_, passive=attributes.PASSIVE_OFF):
        """Return the related object, lazily loading it via the session."""
        if self.key in dict_:
            return dict_[self.key]
        # Retrieve the session bound to the state in order to perform
        # a lazy query for the attribute.
        session = _state_session(state)
        if session is None:
            # State is not bound to a session; we cannot proceed.
            return None
        # Find class for discriminator.
        # TODO: Perhaps optimize with some sort of lookup?
        discriminator = self.get_state_discriminator(state)
        target_class = _get_class_registry(state.class_).get(discriminator)
        if target_class is None:
            # Unknown discriminator; return nothing.
            return None
        id = self.get_state_id(state)
        try:
            target = session.get(target_class, id)
        except AttributeError:
            # sqlalchemy 1.3
            target = session.query(target_class).get(id)
        # Return found (or not found) target.
        return target

    def get_state_discriminator(self, state):
        """Return the discriminator value (target class name) of the state."""
        discriminator = self.parent_token.discriminator
        if isinstance(discriminator, hybrid_property):
            return getattr(state.obj(), discriminator.__name__)
        else:
            return state.attrs[discriminator.key].value

    def get_state_id(self, state):
        """Return the target's primary key tuple stored on the parent row."""
        # Lookup row with the discriminator and id.
        return tuple(state.attrs[id.key].value for id in self.parent_token.id)

    def set(self, state, dict_, initiator,
            passive=attributes.PASSIVE_OFF,
            check_old=None,
            pop=False):
        """Assign the related object, syncing discriminator and id columns."""
        # Set us on the state.
        dict_[self.key] = initiator
        if initiator is None:
            # Nullify relationship args
            for id in self.parent_token.id:
                dict_[id.key] = None
            dict_[self.parent_token.discriminator.key] = None
        else:
            # Get the primary key of the initiator and ensure we
            # can support this assignment.
            class_ = type(initiator)
            mapper = class_mapper(class_)
            pk = mapper.identity_key_from_instance(initiator)[1]
            # Set the identifier and the discriminator.
            discriminator = class_.__name__
            for index, id in enumerate(self.parent_token.id):
                dict_[id.key] = pk[index]
            dict_[self.parent_token.discriminator.key] = discriminator
class GenericRelationshipProperty(MapperProperty):
    """A generic form of the relationship property.
    Creates a 1 to many relationship between the parent model
    and any other models using a discriminator (the table name).
    :param discriminator:
        Field to discriminate which model we are referring to.
    :param id:
        Field to point to the model we are referring to.
    :param doc: Optional docstring for the attribute.
    """
    def __init__(self, discriminator, id, doc=None):
        super().__init__()
        # Raw (possibly string) column references; resolved in init().
        self._discriminator_col = discriminator
        self._id_cols = id
        self._id = None
        self._discriminator = None
        self.doc = doc
        set_creation_order(self)
    def _column_to_property(self, column):
        """Map a column or hybrid property to its ORM descriptor/property."""
        if isinstance(column, hybrid_property):
            attr_key = column.__name__
            for key, attr in self.parent.all_orm_descriptors.items():
                if key == attr_key:
                    return attr
        else:
            for attr in self.parent.attrs.values():
                if isinstance(attr, ColumnProperty):
                    if attr.columns[0].name == column.name:
                        return attr
    def init(self):
        def convert_strings(column):
            # Resolve string column names against the parent mapper.
            if isinstance(column, str):
                return self.parent.columns[column]
            return column
        self._discriminator_col = convert_strings(self._discriminator_col)
        self._id_cols = convert_strings(self._id_cols)
        # Normalize the id columns into a list (composite keys supported).
        if isinstance(self._id_cols, Iterable):
            self._id_cols = list(map(convert_strings, self._id_cols))
        else:
            self._id_cols = [self._id_cols]
        self.discriminator = self._column_to_property(self._discriminator_col)
        if self.discriminator is None:
            raise ImproperlyConfigured(
                'Could not find discriminator descriptor.'
            )
        self.id = list(map(self._column_to_property, self._id_cols))
    class Comparator(PropComparator):
        # Enables query expressions such as Model.object == other and
        # Model.object.is_type(OtherClass).
        def __init__(self, prop, parentmapper):
            self.property = prop
            self._parententity = parentmapper
        def __eq__(self, other):
            # Match both the discriminator and every id column.
            discriminator = type(other).__name__
            q = self.property._discriminator_col == discriminator
            other_id = identity(other)
            for index, id in enumerate(self.property._id_cols):
                q &= id == other_id[index]
            return q
        def __ne__(self, other):
            return ~(self == other)
        def is_type(self, other):
            """Return a criterion matching the given type or any subtype."""
            mapper = sa.inspect(other)
            # Iterate through the weak sequence in order to get the actual
            # mappers
            class_names = [other.__name__]
            class_names.extend([
                submapper.class_.__name__
                for submapper in mapper._inheriting_mappers
            ])
            return self.property._discriminator_col.in_(class_names)
    def instrument_class(self, mapper):
        """Register the generic attribute on the mapped class."""
        attributes.register_attribute(
            mapper.class_,
            self.key,
            comparator=self.Comparator(self, mapper),
            parententity=mapper,
            doc=self.doc,
            impl_class=GenericAttributeImpl,
            parent_token=self
        )
def generic_relationship(*args, **kwargs):
    """Shorthand constructor for :class:`GenericRelationshipProperty`.
    :param discriminator: field discriminating which model is referred to
    :param id: field (or fields) storing the referred model's identifier
    """
    return GenericRelationshipProperty(*args, **kwargs)

View file

@ -0,0 +1,119 @@
import sqlalchemy as sa
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.ext.hybrid import hybrid_property
from sqlalchemy.sql.expression import ColumnElement
from .exceptions import ImproperlyConfigured
try:
import babel
import babel.dates
except ImportError:
babel = None
def get_locale():
    """
    Return the default locale (``en``) using Babel. Override this function
    in this module to customize locale resolution; raises
    ImproperlyConfigured when Babel is not installed.
    """
    try:
        locale_cls = babel.Locale
    except AttributeError:
        # As babel is optional, we may raise an AttributeError accessing it
        raise ImproperlyConfigured(
            'Could not load get_locale function using Babel. Either '
            'install Babel or make a similar function and override it '
            'in this module.'
        )
    return locale_cls('en')
def cast_locale(obj, locale, attr):
    """
    Cast given locale to string. Supports also callbacks that return locales.
    :param obj:
        Object or class to use as a possible parameter to locale callable
    :param locale:
        Locale object or string or callable that returns a locale.
    :param attr:
        Attribute whose key may be passed to the locale callable.
    """
    if callable(locale):
        # Try calling with (obj, attr.key), then (obj,), then no arguments.
        try:
            locale = locale(obj, attr.key)
        except TypeError:
            try:
                locale = locale(obj)
            except TypeError:
                locale = locale()
    if isinstance(locale, babel.Locale):
        return str(locale)
    return locale
class cast_locale_expr(ColumnElement):
    """SQL expression node that defers locale casting until compile time."""
    inherit_cache = False
    def __init__(self, cls, locale, attr):
        # Stored as-is; resolved by compile_cast_locale_expr.
        self.cls = cls
        self.locale = locale
        self.attr = attr
@compiles(cast_locale_expr)
def compile_cast_locale_expr(element, compiler, **kw):
    """Compile a cast_locale_expr into an inline SQL string literal, or
    delegate to the compiler when the resolved locale is an expression."""
    locale = cast_locale(element.cls, element.locale, element.attr)
    if not isinstance(locale, str):
        return compiler.process(locale)
    return f"'{locale}'"
class TranslationHybrid:
    """Factory producing hybrid properties that read and write a
    locale-keyed mapping column, falling back from the current locale to
    the default locale."""
    def __init__(self, current_locale, default_locale, default_value=None):
        # current_locale / default_locale may be locales, strings or
        # callables resolved per-object via cast_locale.
        if babel is None:
            raise ImproperlyConfigured(
                'You need to install babel in order to use TranslationHybrid.'
            )
        self.current_locale = current_locale
        self.default_locale = default_locale
        self.default_value = default_value
    def getter_factory(self, attr):
        """
        Return a hybrid_property getter function for given attribute. The
        returned getter first checks if object has translation for current
        locale. If not it tries to get translation for default locale. If there
        is no translation found for default locale it returns the configured
        default value.
        """
        def getter(obj):
            current_locale = cast_locale(obj, self.current_locale, attr)
            try:
                return getattr(obj, attr.key)[current_locale]
            except (TypeError, KeyError):
                # No translation (or no mapping at all) for current locale.
                default_locale = cast_locale(obj, self.default_locale, attr)
                try:
                    return getattr(obj, attr.key)[default_locale]
                except (TypeError, KeyError):
                    return self.default_value
        return getter
    def setter_factory(self, attr):
        """Return a setter writing the value under the current locale key."""
        def setter(obj, value):
            if getattr(obj, attr.key) is None:
                setattr(obj, attr.key, {})
            locale = cast_locale(obj, self.current_locale, attr)
            getattr(obj, attr.key)[locale] = value
        return setter
    def expr_factory(self, attr):
        """Return a SQL expression coalescing current and default locales."""
        def expr(cls):
            cls_attr = getattr(cls, attr.key)
            current_locale = cast_locale_expr(cls, self.current_locale, attr)
            default_locale = cast_locale_expr(cls, self.default_locale, attr)
            return sa.func.coalesce(
                cls_attr[current_locale],
                cls_attr[default_locale]
            )
        return expr
    def __call__(self, attr):
        """Build the hybrid_property for the given mapped attribute."""
        return hybrid_property(
            fget=self.getter_factory(attr),
            fset=self.setter_factory(attr),
            expr=self.expr_factory(attr)
        )

View file

@ -0,0 +1,277 @@
import sqlalchemy as sa
from .exceptions import ImproperlyConfigured
def coercion_listener(mapper, class_):
    """
    Auto assigns coercing listener for all class properties which are of coerce
    capable type.
    """
    for prop in mapper.iterate_properties:
        try:
            listener = prop.columns[0].type.coercion_listener
        except AttributeError:
            # Not a column property, or its type defines no coercion listener.
            continue
        sa.event.listen(
            getattr(class_, prop.key),
            'set',
            listener,
            retval=True
        )
def instant_defaults_listener(target, args, kwargs):
    """'init' event listener that injects column defaults into kwargs so
    they take effect at object construction time."""
    # insertion order of kwargs matters
    # copy and clear so that we can add back later at the end of the dict
    original = kwargs.copy()
    kwargs.clear()
    for key, column in sa.inspect(target.__class__).columns.items():
        if (
            hasattr(column, 'default') and
            column.default is not None
        ):
            if callable(column.default.arg):
                # NOTE(review): callable defaults are invoked with the target
                # instance rather than an ExecutionContext — confirm intent.
                kwargs[key] = column.default.arg(target)
            else:
                kwargs[key] = column.default.arg
    # supersede w/initial in case target uses setters overriding defaults
    kwargs.update(original)
def force_auto_coercion(mapper=None):
    """
    Enable automatic data type coercion for every class mapped by the given
    mapper (all SQLAlchemy mappers by default). Coercion is applied to all
    coercion capable properties.
    Call force_auto_coercion before initializing your models.
    ::
        from sqlalchemy_utils import force_auto_coercion
        force_auto_coercion()
    Then define your models the usual way::
        class Document(Base):
            __tablename__ = 'document'
            id = sa.Column(sa.Integer, autoincrement=True)
            name = sa.Column(sa.Unicode(50))
            background_color = sa.Column(ColorType)
    Scalar values assigned to coercion capable columns are now converted to
    the appropriate value objects::
        document = Document()
        document.background_color = 'F5F5F5'
        document.background_color  # Color object
        session.commit()
    A useful side effect is that data gets validated at assignment time
    rather than at flush time. For example an invalid
    :class:`sqlalchemy_utils.types.IPAddressType` value (eg. ``10.0.0
    255.255``) would otherwise slip through: most databases have no native
    IP address type and store it as a string, and the ``ipaddress`` package
    uses a string field as well.
    :param mapper: The mapper which the automatic data type coercion should be
                   applied to
    """
    target_mapper = sa.orm.Mapper if mapper is None else mapper
    sa.event.listen(target_mapper, 'mapper_configured', coercion_listener)
def force_instant_defaults(mapper=None):
    """
    Make column defaults take effect already at object construction time.
    By default this applies to all your models.
    Setting up instant defaults::
        from sqlalchemy_utils import force_instant_defaults
        force_instant_defaults()
    Example usage::
        class Document(Base):
            __tablename__ = 'document'
            id = sa.Column(sa.Integer, autoincrement=True)
            name = sa.Column(sa.Unicode(50))
            created_at = sa.Column(sa.DateTime, default=datetime.now)
        document = Document()
        document.created_at  # datetime object
    :param mapper: The mapper which the automatic instant defaults forcing
                   should be applied to
    """
    target_mapper = sa.orm.Mapper if mapper is None else mapper
    sa.event.listen(target_mapper, 'init', instant_defaults_listener)
def auto_delete_orphans(attr):
    """
    Delete orphans for given SQLAlchemy model attribute. This function can be
    used for deleting many-to-many associated orphans easily. For more
    information see
    https://bitbucket.org/zzzeek/sqlalchemy/wiki/UsageRecipes/ManyToManyOrphan.
    Consider the following model definition:
    ::
        from sqlalchemy.ext.associationproxy import association_proxy
        from sqlalchemy import *
        from sqlalchemy.orm import *
        # Necessary in sqlalchemy 1.3:
        # from sqlalchemy.ext.declarative import declarative_base
        from sqlalchemy import event
        Base = declarative_base()
        tagging = Table(
            'tagging',
            Base.metadata,
            Column(
                'tag_id',
                Integer,
                ForeignKey('tag.id', ondelete='CASCADE'),
                primary_key=True
            ),
            Column(
                'entry_id',
                Integer,
                ForeignKey('entry.id', ondelete='CASCADE'),
                primary_key=True
            )
        )
        class Tag(Base):
            __tablename__ = 'tag'
            id = Column(Integer, primary_key=True)
            name = Column(String(100), unique=True, nullable=False)
            def __init__(self, name=None):
                self.name = name
        class Entry(Base):
            __tablename__ = 'entry'
            id = Column(Integer, primary_key=True)
            tags = relationship(
                'Tag',
                secondary=tagging,
                backref='entries'
            )
    Now lets say we want to delete the tags if all their parents get deleted (
    all Entry objects get deleted). This can be achieved as follows:
    ::
        from sqlalchemy_utils import auto_delete_orphans
        auto_delete_orphans(Entry.tags)
    After we've set up this listener we can see it in action.
    ::
        e = create_engine('sqlite://')
        Base.metadata.create_all(e)
        s = Session(e)
        r1 = Entry()
        r2 = Entry()
        r3 = Entry()
        t1, t2, t3, t4 = Tag('t1'), Tag('t2'), Tag('t3'), Tag('t4')
        r1.tags.extend([t1, t2])
        r2.tags.extend([t2, t3])
        r3.tags.extend([t4])
        s.add_all([r1, r2, r3])
        assert s.query(Tag).count() == 4
        r2.tags.remove(t2)
        assert s.query(Tag).count() == 4
        r1.tags.remove(t2)
        assert s.query(Tag).count() == 3
        r1.tags.remove(t1)
        assert s.query(Tag).count() == 2
    .. versionadded: 0.26.4
    :param attr: Association relationship attribute to auto delete orphans from
    :raises ImproperlyConfigured: if the relationship has no backref set
    """
    parent_class = attr.parent.class_
    target_class = attr.property.mapper.class_
    backref = attr.property.backref
    if not backref:
        raise ImproperlyConfigured(
            'The relationship argument given for auto_delete_orphans needs to '
            'have a backref relationship set.'
        )
    if isinstance(backref, tuple):
        # backref may be given as (name, kwargs); only the name is needed.
        backref = backref[0]
    @sa.event.listens_for(sa.orm.Session, 'after_flush')
    def delete_orphan_listener(session, ctx):
        # Look through Session state to see if we want to emit a DELETE for
        # orphans
        orphans_found = (
            any(
                isinstance(obj, parent_class) and
                sa.orm.attributes.get_history(obj, attr.key).deleted
                for obj in session.dirty
            ) or
            any(
                isinstance(obj, parent_class)
                for obj in session.deleted
            )
        )
        if orphans_found:
            # Emit a DELETE for all orphans
            (
                session.query(target_class)
                .filter(
                    ~getattr(target_class, backref).any()
                )
                .delete(synchronize_session=False)
            )

View file

@ -0,0 +1,96 @@
from datetime import datetime
import sqlalchemy as sa
class Timestamp:
    """Adds `created` and `updated` columns to a derived declarative model.
    The `created` column is handled through a default and the `updated`
    column is handled through a `before_update` event that propagates
    for all derived declarative models.
    ::
        import sqlalchemy as sa
        from sqlalchemy_utils import Timestamp
        class SomeModel(Base, Timestamp):
            __tablename__ = 'somemodel'
            id = sa.Column(sa.Integer, primary_key=True)
    """
    # Set once when the row is first created.
    created = sa.Column(sa.DateTime, default=datetime.utcnow, nullable=False)
    # Refreshed on every update by timestamp_before_update.
    updated = sa.Column(sa.DateTime, default=datetime.utcnow, nullable=False)
@sa.event.listens_for(Timestamp, 'before_update', propagate=True)
def timestamp_before_update(mapper, connection, target):
    """Refresh the `updated` timestamp before flushing an UPDATE."""
    # When a model with a timestamp is updated; force update the updated
    # timestamp.
    target.updated = datetime.utcnow()
# Placeholder shown in generic reprs for deferred/unloaded attributes.
NOT_LOADED_REPR = '<not loaded>'
def _generic_repr_method(self, fields):
    """Build a repr string from the given (or all mapped column) fields,
    without triggering loads for unloaded attributes."""
    state = sa.inspect(self)
    keys = fields or state.mapper.columns.keys()
    parts = []
    for key in keys:
        if key in state.unloaded:
            rendered = NOT_LOADED_REPR
        else:
            rendered = repr(state.attrs[key].loaded_value)
        parts.append(f'{key}={rendered}')
    return f"{self.__class__.__name__}({', '.join(parts)})"
def generic_repr(*fields):
    """Adds generic ``__repr__()`` method to a declarative SQLAlchemy model.
    Fields that are not loaded from the database are not loaded by the
    repr either; they are shown as ``<not loaded>``. Field names may be
    passed to the decorator to choose which fields appear in the string
    representation and in what order.
    Example::
        import sqlalchemy as sa
        from sqlalchemy_utils import generic_repr
        @generic_repr
        class MyModel(Base):
            __tablename__ = 'mymodel'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.String)
            category = sa.Column(sa.String)
        session.add(MyModel(name='Foo', category='Bar'))
        session.commit()
        foo = session.query(MyModel).options(sa.orm.defer('category')).one()
        assert repr(foo) == "MyModel(id=1, name='Foo', category=<not loaded>)"
    """
    def _attach(cls, selected):
        # __repr__ is bound lazily; _generic_repr_method inspects instance
        # state only when repr() is actually called.
        cls.__repr__ = lambda self: _generic_repr_method(self, fields=selected)
        return cls

    if len(fields) == 1 and callable(fields[0]):
        # Used as a bare decorator: @generic_repr
        return _attach(fields[0], None)
    # Used with arguments: @generic_repr('id', 'name')
    return lambda cls: _attach(cls, fields)

View file

@ -0,0 +1,376 @@
"""
This module provides a decorator function for observing changes in a given
property. Internally the decorator is implemented using SQLAlchemy event
listeners. Both column properties and relationship properties can be observed.
Property observers can be used for pre-calculating aggregates and automatic
real-time data denormalization.
Simple observers
----------------
At the heart of the observer extension is the :func:`observes` decorator. You
mark some property path as being observed and the marked method will get
notified when any changes are made to given path.
Consider the following model structure:
::
class Director(Base):
__tablename__ = 'director'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String)
date_of_birth = sa.Column(sa.Date)
class Movie(Base):
__tablename__ = 'movie'
id = sa.Column(sa.Integer, primary_key=True)
name = sa.Column(sa.String)
director_id = sa.Column(sa.Integer, sa.ForeignKey(Director.id))
director = sa.orm.relationship(Director, backref='movies')
Now consider we want to show movies in some listing ordered by director id
first and movie id secondly. If we have many movies then using joins and
ordering by Director.name will be very slow. Here is where denormalization
and :func:`observes` comes to rescue the day. Let's add a new column called
director_name to Movie which will get automatically copied from associated
Director.
::
from sqlalchemy_utils import observes
class Movie(Base):
# same as before..
director_name = sa.Column(sa.String)
@observes('director')
def director_observer(self, director):
self.director_name = director.name
.. note::
This example could be done much more efficiently using a compound foreign
key from director_name, director_id to Director.name, Director.id but for
the sake of simplicity we added this as an example.
Observes vs aggregated
----------------------
:func:`observes` and :func:`.aggregates.aggregated` can be used for similar
things. However performance wise you should take the following things into
consideration:
* :func:`observes` works always inside transaction and deals with objects. If
the relationship observer is observing has a large number of objects it's
better to use :func:`.aggregates.aggregated`.
* :func:`.aggregates.aggregated` always executes one additional query per
aggregate so in scenarios where the observed relationship has only a handful
of objects it's better to use :func:`observes` instead.
Example 1. Movie with many ratings
Let's say we have a Movie object with potentially thousands of ratings. In this
case we should always use :func:`.aggregates.aggregated` since iterating
through thousands of objects is slow and very memory consuming.
Example 2. Product with denormalized catalog name
Each product belongs to one catalog. Here it is natural to use :func:`observes`
for data denormalization.
Deeply nested observing
-----------------------
Consider the following model structure where Catalog has many Categories and
Category has many Products.
::
class Catalog(Base):
__tablename__ = 'catalog'
id = sa.Column(sa.Integer, primary_key=True)
product_count = sa.Column(sa.Integer, default=0)
@observes('categories.products')
def product_observer(self, products):
self.product_count = len(products)
categories = sa.orm.relationship('Category', backref='catalog')
class Category(Base):
__tablename__ = 'category'
id = sa.Column(sa.Integer, primary_key=True)
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
products = sa.orm.relationship('Product', backref='category')
class Product(Base):
__tablename__ = 'product'
id = sa.Column(sa.Integer, primary_key=True)
price = sa.Column(sa.Numeric)
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
:func:`observes` is smart enough to:
* Notify catalog objects of any changes in associated Product objects
* Notify catalog objects of any changes in Category objects that affect
products (for example if Category gets deleted, or a new Category is added to
Catalog with any number of Products)
::
category = Category(
products=[Product(), Product()]
)
category2 = Category(
product=[Product()]
)
catalog = Catalog(
categories=[category, category2]
)
session.add(catalog)
session.commit()
catalog.product_count # 2
session.delete(category)
session.commit()
catalog.product_count # 1
Observing multiple columns
--------------------------
You can also observe multiple columns by specifying all the observable columns
in the decorator.
::
class Order(Base):
__tablename__ = 'order'
id = sa.Column(sa.Integer, primary_key=True)
unit_price = sa.Column(sa.Integer)
amount = sa.Column(sa.Integer)
total_price = sa.Column(sa.Integer)
@observes('amount', 'unit_price')
def total_price_observer(self, amount, unit_price):
self.total_price = amount * unit_price
"""
import itertools
from collections import defaultdict, namedtuple
from collections.abc import Iterable
import sqlalchemy as sa
from .functions import getdotattr, has_changes
from .path import AttrPath
from .utils import is_sequence
Callback = namedtuple('Callback', ['func', 'backref', 'fullpath'])
class PropertyObserver:
    """
    Coordinates the machinery behind the :func:`observes` decorator.

    Three SQLAlchemy events are wired:

    * ``mapper_configured`` -- collect decorated methods per mapped class
    * ``after_configured``  -- resolve observed dot-paths into callbacks
    * ``before_flush``      -- invoke callbacks for objects changed in the
      session
    """

    def __init__(self):
        # (target, event name, handler) triples kept so listeners can be
        # added and removed symmetrically.
        self.listener_args = [
            (
                sa.orm.Mapper,
                'mapper_configured',
                self.update_generator_registry
            ),
            (
                sa.orm.Mapper,
                'after_configured',
                self.gather_paths
            ),
            (
                sa.orm.session.Session,
                'before_flush',
                self.invoke_callbacks
            )
        ]
        # class -> [Callback, ...], populated by gather_paths.
        self.callback_map = defaultdict(list)
        # TODO: make the registry a WeakKey dict
        # class -> [decorated observer methods], populated per mapper.
        self.generator_registry = defaultdict(list)

    def remove_listeners(self):
        """Detach all event listeners registered by this observer."""
        for args in self.listener_args:
            sa.event.remove(*args)

    def register_listeners(self):
        """Attach event listeners, skipping any already registered."""
        for args in self.listener_args:
            if not sa.event.contains(*args):
                sa.event.listen(*args)

    def __repr__(self):
        return '<PropertyObserver>'

    def update_generator_registry(self, mapper, class_):
        """
        Adds generator functions to generator_registry.
        """
        # Only methods marked by the @observes decorator carry __observes__.
        for generator in class_.__dict__.values():
            if hasattr(generator, '__observes__'):
                self.generator_registry[class_].append(
                    generator
                )

    def gather_paths(self):
        """
        Resolve every observed dot-path into Callback entries.

        Each observer is registered for its own class (backref=None) and,
        for every relationship hop along the path, for the related class
        with an inverted path back to the root, so that changes to related
        objects can locate the root objects to notify.
        """
        for class_, generators in self.generator_registry.items():
            for callback in generators:
                full_paths = []
                for call_path in callback.__observes__:
                    full_paths.append(AttrPath(class_, call_path))
                for path in full_paths:
                    # NOTE(review): one root-class Callback is appended per
                    # observed path; duplicates collapse later because
                    # invoke_callbacks groups gathered args by function.
                    self.callback_map[class_].append(
                        Callback(
                            func=callback,
                            backref=None,
                            fullpath=full_paths
                        )
                    )
                    for index in range(len(path)):
                        i = index + 1
                        prop = path[index].property
                        if isinstance(prop, sa.orm.RelationshipProperty):
                            prop_class = path[index].property.mapper.class_
                            self.callback_map[prop_class].append(
                                Callback(
                                    func=callback,
                                    # ~path[:i] inverts the path prefix via
                                    # relationship backrefs.
                                    backref=~ (path[:i]),
                                    fullpath=full_paths
                                )
                            )

    def gather_callback_args(self, obj, callbacks):
        """Yield (root_obj, func, objects) tuples for a changed object."""
        session = sa.orm.object_session(obj)
        for callback in callbacks:
            backref = callback.backref
            # Walk back from the changed object to its root object(s).
            root_objs = getdotattr(obj, backref) if backref else obj
            if root_objs:
                if not isinstance(root_objs, Iterable):
                    root_objs = [root_objs]
                # no_autoflush: attribute access here must not trigger a
                # nested flush while we are inside before_flush.
                with session.no_autoflush:
                    for root_obj in root_objs:
                        if root_obj:
                            args = self.get_callback_args(root_obj, callback)
                            if args:
                                yield args

    def get_callback_args(self, root_obj, callback):
        """Return (root_obj, func, objects) if the callback should fire."""
        session = sa.orm.object_session(root_obj)
        # Resolve each observed path, skipping objects pending deletion.
        objects = [getdotattr(
            root_obj,
            path,
            lambda obj: obj not in session.deleted
        ) for path in callback.fullpath]
        paths = [str(path) for path in callback.fullpath]
        for path in paths:
            # Multi-hop paths always fire; single-column paths only fire
            # when the column actually changed on root_obj.
            if '.' in path or has_changes(root_obj, path):
                return (
                    root_obj,
                    callback.func,
                    objects
                )

    def iterate_objects_and_callbacks(self, session):
        """Yield (obj, callbacks) for every new/dirty/deleted object."""
        objs = itertools.chain(session.new, session.dirty, session.deleted)
        for obj in objs:
            for class_, callbacks in self.callback_map.items():
                if isinstance(obj, class_):
                    yield obj, callbacks

    def invoke_callbacks(self, session, ctx, instances):
        """before_flush hook: aggregate gathered args and call observers."""
        # root_obj -> func -> positional-arg accumulator. The inner default
        # (set) is immediately replaced by a dict keyed by arg position.
        callback_args = defaultdict(lambda: defaultdict(set))
        for obj, callbacks in self.iterate_objects_and_callbacks(session):
            args = self.gather_callback_args(obj, callbacks)
            for (root_obj, func, objects) in args:
                if not callback_args[root_obj][func]:
                    callback_args[root_obj][func] = {}
                for i, object_ in enumerate(objects):
                    if is_sequence(object_):
                        # Union collection-valued args gathered from
                        # multiple changed objects.
                        callback_args[root_obj][func][i] = (
                            callback_args[root_obj][func].get(i, set()) |
                            set(object_)
                        )
                    else:
                        callback_args[root_obj][func][i] = object_
        for root_obj, callback_objs in callback_args.items():
            for callback, objs in callback_objs.items():
                # Invoke with args in observed-path order.
                callback(root_obj, *[objs[i] for i in range(len(objs))])
observer = PropertyObserver()
def observes(*paths, **observer_kw):
    """
    Mark method as property observer for the given property path. Inside
    transaction observer gathers all changes made in given property path and
    feeds the changed objects to observer-marked method at the before flush
    phase.

    ::

        from sqlalchemy_utils import observes


        class Catalog(Base):
            __tablename__ = 'catalog'
            id = sa.Column(sa.Integer, primary_key=True)
            category_count = sa.Column(sa.Integer, default=0)

            @observes('categories')
            def category_observer(self, categories):
                self.category_count = len(categories)

        class Category(Base):
            __tablename__ = 'category'
            id = sa.Column(sa.Integer, primary_key=True)
            catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))

        catalog = Catalog(categories=[Category(), Category()])
        session.add(catalog)
        session.commit()
        catalog.category_count  # 2

    .. versionadded:: 0.28.0

    :param paths: One or more dot-notated property paths, eg.
        'categories.products.price'
    :param observer_kw: A dictionary where value for key 'observer' contains
        :meth:`PropertyObserver` object
    """
    observer_ = observer_kw.pop('observer', observer)
    observer_.register_listeners()

    def decorator(func):
        # The wrapper only forwards the call; decoration exists to attach
        # the __observes__ metadata that PropertyObserver picks up when
        # mappers are configured.
        def wrapper(self, *args, **kwargs):
            return func(self, *args, **kwargs)

        wrapper.__observes__ = paths
        return wrapper

    return decorator

View file

@ -0,0 +1,74 @@
import sqlalchemy as sa
def inspect_type(mixed):
    """Return the SQLAlchemy type object behind an instrumented attribute,
    column property or plain column; falls through (None) otherwise."""
    if isinstance(mixed, sa.orm.attributes.InstrumentedAttribute):
        return mixed.property.columns[0].type
    if isinstance(mixed, sa.orm.ColumnProperty):
        return mixed.columns[0].type
    if isinstance(mixed, sa.Column):
        return mixed.type
def is_case_insensitive(mixed):
    """Return True if *mixed*'s type uses CaseInsensitiveComparator,
    checking first an instantiated comparator then the comparator factory."""
    type_ = inspect_type(mixed)
    try:
        return isinstance(type_.comparator, CaseInsensitiveComparator)
    except AttributeError:
        pass
    try:
        return issubclass(
            type_.comparator_factory,
            CaseInsensitiveComparator
        )
    except AttributeError:
        return False
class CaseInsensitiveComparator(sa.Unicode.Comparator):
    """
    Comparator that lower-cases the right-hand operand of string operations
    unless that operand is itself already case-insensitive.
    """

    @classmethod
    def lowercase_arg(cls, func):
        # Build a wrapper around the named sa.Unicode.Comparator operator
        # that routes the other operand through sa.func.lower() first.
        def operation(self, other, **kwargs):
            operator = getattr(sa.Unicode.Comparator, func)
            if other is None:
                # NULL comparisons are passed through untouched.
                return operator(self, other, **kwargs)
            if not is_case_insensitive(other):
                other = sa.func.lower(other)
            return operator(self, other, **kwargs)
        return operation

    def in_(self, other):
        # Lowercase every element when a plain list/tuple is given.
        if isinstance(other, list) or isinstance(other, tuple):
            other = map(sa.func.lower, other)
        return sa.Unicode.Comparator.in_(self, other)

    def notin_(self, other):
        # Mirror of in_ for the negated membership test.
        if isinstance(other, list) or isinstance(other, tuple):
            other = map(sa.func.lower, other)
        return sa.Unicode.Comparator.notin_(self, other)
# Comparator operators whose right-hand operand should be lowercased.
string_operator_funcs = [
    '__eq__',
    '__ne__',
    '__lt__',
    '__le__',
    '__gt__',
    '__ge__',
    'concat',
    'contains',
    'ilike',
    'like',
    'notlike',
    'notilike',
    'startswith',
    'endswith',
]
# Replace each listed operator on the comparator with a wrapper that
# lowercases its argument (see lowercase_arg above).
for func in string_operator_funcs:
    setattr(
        CaseInsensitiveComparator,
        func,
        CaseInsensitiveComparator.lowercase_arg(func)
    )

View file

@ -0,0 +1,152 @@
import sqlalchemy as sa
from sqlalchemy.orm.attributes import InstrumentedAttribute
from sqlalchemy.util.langhelpers import symbol
from .utils import str_coercible
@str_coercible
class Path:
    """A separator-delimited (``'.'`` by default) attribute path string."""

    def __init__(self, path, separator='.'):
        # Accept either a raw string or another Path instance.
        self.path = path.path if isinstance(path, Path) else path
        self.separator = separator

    @property
    def parts(self):
        """The path split into its individual segments."""
        return self.path.split(self.separator)

    def __iter__(self):
        return iter(self.parts)

    def __len__(self):
        return len(self.parts)

    def __repr__(self):
        return f"{self.__class__.__name__}('{self.path}')"

    def index(self, element):
        """Position of *element* among the path segments."""
        return self.parts.index(element)

    def __getitem__(self, key):
        selected = self.parts[key]
        if not isinstance(selected, list):
            # Integer index: return the bare segment string.
            return selected
        # Slice: rebuild a Path from the selected segments.
        return self.__class__(
            self.separator.join(selected),
            separator=self.separator
        )

    def __eq__(self, other):
        return self.path == other.path and self.separator == other.separator

    def __ne__(self, other):
        return not (self == other)

    def __unicode__(self):
        return self.path
def get_attr(mixed, attr):
    """Fetch *attr*, unwrapping an InstrumentedAttribute relationship
    descriptor to its target mapped class first."""
    if isinstance(mixed, InstrumentedAttribute):
        target = mixed.property.mapper.class_
        return getattr(target, attr)
    return getattr(mixed, attr)
@str_coercible
class AttrPath:
    """
    A dot-path resolved against a mapped class: each segment is looked up
    as an instrumented attribute, hopping across relationships.
    """

    def __init__(self, class_, path):
        self.class_ = class_
        self.path = Path(path)
        # Instrumented attributes, one per path segment, resolved in order.
        self.parts = []
        last_attr = class_
        for value in self.path:
            last_attr = get_attr(last_attr, value)
            self.parts.append(last_attr)

    def __iter__(self):
        yield from self.parts

    def __invert__(self):
        """Return the reversed path built from relationship backrefs."""
        def get_backref(part):
            prop = part.property
            backref = prop.backref or prop.back_populates
            if backref is None:
                raise Exception(
                    "Invert failed because property '%s' of class "
                    "%s has no backref." % (
                        prop.key,
                        prop.parent.class_.__name__
                    )
                )
            # backref may be given as (name, kwargs) tuple.
            if isinstance(backref, tuple):
                return backref[0]
            else:
                return backref

        # The inverted path starts at the class the last hop lands on.
        if isinstance(self.parts[-1].property, sa.orm.ColumnProperty):
            class_ = self.parts[-1].class_
        else:
            class_ = self.parts[-1].mapper.class_

        return self.__class__(
            class_,
            '.'.join(map(get_backref, reversed(self.parts)))
        )

    def index(self, element):
        # Identity-based search; returns None if not found.
        for index, el in enumerate(self.parts):
            if el is element:
                return index

    @property
    def direction(self):
        """Combined relationship direction over all hops in the path."""
        symbols = [part.property.direction for part in self.parts]
        if symbol('MANYTOMANY') in symbols:
            return symbol('MANYTOMANY')
        elif symbol('MANYTOONE') in symbols and symbol('ONETOMANY') in symbols:
            # A many-to-one chained with a one-to-many acts many-to-many.
            return symbol('MANYTOMANY')
        return symbols[0]

    @property
    def uselist(self):
        # True if any hop yields a collection.
        return any(part.property.uselist for part in self.parts)

    def __getitem__(self, slice):
        result = self.parts[slice]
        if isinstance(result, list) and result:
            # Re-root the sliced path on the correct class.
            if result[0] is self.parts[0]:
                class_ = self.class_
            else:
                class_ = result[0].parent.class_
            return self.__class__(
                class_,
                self.path[slice]
            )
        else:
            return result

    def __len__(self):
        return len(self.path)

    def __repr__(self):
        return "{}({}, {!r})".format(
            self.__class__.__name__,
            self.class_.__name__,
            self.path.path
        )

    def __eq__(self, other):
        return self.path == other.path and self.class_ == other.class_

    def __ne__(self, other):
        return not (self == other)

    def __unicode__(self):
        return str(self.path)

View file

@ -0,0 +1,5 @@
from .country import Country # noqa
from .currency import Currency # noqa
from .ltree import Ltree # noqa
from .weekday import WeekDay # noqa
from .weekdays import WeekDays # noqa

View file

@ -0,0 +1,110 @@
from functools import total_ordering
from .. import i18n
from ..utils import str_coercible
@total_ordering
@str_coercible
class Country:
    """
    Country class wraps a 2 to 3 letter country code. It provides various
    convenience properties and methods.

    ::

        from babel import Locale
        from sqlalchemy_utils import Country, i18n

        # First lets add a locale getter for testing purposes
        i18n.get_locale = lambda: Locale('en')

        Country('FI').name  # Finland
        Country('FI').code  # FI

        Country(Country('FI')).code  # 'FI'

    Country always validates the given code if you use at least the optional
    dependency list 'babel', otherwise no validation are performed.

    ::

        Country(None)  # raises TypeError

        Country('UnknownCode')  # raises ValueError

    Country supports equality operators.

    ::

        Country('FI') == Country('FI')
        Country('FI') != Country('US')

    Country objects are hashable.

    ::

        assert hash(Country('FI')) == hash('FI')
    """

    def __init__(self, code_or_country):
        """
        :param code_or_country: ISO 3166 style country code string, or an
            existing Country whose code is copied.
        :raises TypeError: if the argument is neither str nor Country.
        :raises ValueError: if babel is installed and the code is unknown.
        """
        if isinstance(code_or_country, Country):
            self.code = code_or_country.code
        elif isinstance(code_or_country, str):
            self.validate(code_or_country)
            self.code = code_or_country
        else:
            raise TypeError(
                "Country() argument must be a string or a country, not '{}'"
                .format(
                    type(code_or_country).__name__
                )
            )

    @property
    def name(self):
        """Territory name for the currently active locale."""
        # Resolved lazily so the name follows i18n.get_locale() at call time.
        return i18n.get_locale().territories[self.code]

    @classmethod
    def validate(cls, code):
        """Raise ValueError if babel is available and *code* is unknown.

        (First parameter renamed from ``self`` to ``cls``: this is a
        classmethod and always receives the class.)
        """
        try:
            i18n.babel.Locale('en').territories[code]
        except KeyError:
            raise ValueError(
                f'Could not convert string to country code: {code}'
            )
        except AttributeError:
            # As babel is optional, we may raise an AttributeError accessing it
            pass

    def __eq__(self, other):
        # Comparable both against other Country objects and bare code strings.
        if isinstance(other, Country):
            return self.code == other.code
        elif isinstance(other, str):
            return self.code == other
        else:
            return NotImplemented

    def __hash__(self):
        # Hash by code so Country('FI') and 'FI' collide in sets/dicts.
        return hash(self.code)

    def __ne__(self, other):
        return not (self == other)

    def __lt__(self, other):
        # total_ordering fills in the remaining comparisons from this + __eq__.
        if isinstance(other, Country):
            return self.code < other.code
        elif isinstance(other, str):
            return self.code < other
        return NotImplemented

    def __repr__(self):
        return f'{self.__class__.__name__}({self.code!r})'

    def __unicode__(self):
        return self.name

View file

@ -0,0 +1,109 @@
from .. import i18n, ImproperlyConfigured
from ..utils import str_coercible
@str_coercible
class Currency:
    """
    Currency class wraps a 3-letter currency code. It provides various
    convenience properties and methods.

    ::

        from babel import Locale
        from sqlalchemy_utils import Currency, i18n

        # First lets add a locale getter for testing purposes
        i18n.get_locale = lambda: Locale('en')

        Currency('USD').name  # US Dollar
        Currency('USD').symbol  # $

        Currency(Currency('USD')).code  # 'USD'

    Currency always validates the given code if you use at least the optional
    dependency list 'babel', otherwise no validation are performed.

    ::

        Currency(None)  # raises TypeError

        Currency('UnknownCode')  # raises ValueError

    Currency supports equality operators.

    ::

        Currency('USD') == Currency('USD')
        Currency('USD') != Currency('EUR')

    Currencies are hashable.

    ::

        len(set([Currency('USD'), Currency('USD')]))  # 1
    """

    def __init__(self, code):
        """
        :param code: three-letter currency code string, or an existing
            Currency whose code is copied.
        :raises ImproperlyConfigured: if babel is not installed.
        :raises TypeError: if *code* is neither str nor Currency.
        :raises ValueError: if the code is not a known currency.
        """
        if i18n.babel is None:
            raise ImproperlyConfigured(
                "'babel' package is required in order to use Currency class."
            )
        if isinstance(code, Currency):
            # BUGFIX: copy the underlying code string; previously the
            # Currency instance itself was stored, so copy-constructed
            # objects carried a nested Currency instead of a str code.
            self.code = code.code
        elif isinstance(code, str):
            self.validate(code)
            self.code = code
        else:
            raise TypeError(
                'First argument given to Currency constructor should be '
                'either an instance of Currency or valid three letter '
                'currency code.'
            )

    @classmethod
    def validate(cls, code):
        """Raise ValueError if babel is available and *code* is unknown.

        (First parameter renamed from ``self`` to ``cls``: this is a
        classmethod and always receives the class.)
        """
        try:
            i18n.babel.Locale('en').currencies[code]
        except KeyError:
            raise ValueError(f"'{code}' is not valid currency code.")
        except AttributeError:
            # As babel is optional, we may raise an AttributeError accessing it
            pass

    @property
    def symbol(self):
        """Currency symbol (eg. ``$``) for the currently active locale."""
        return i18n.babel.numbers.get_currency_symbol(
            self.code,
            i18n.get_locale()
        )

    @property
    def name(self):
        """Currency display name for the currently active locale."""
        return i18n.get_locale().currencies[self.code]

    def __eq__(self, other):
        # Comparable both against other Currency objects and bare code strings.
        if isinstance(other, Currency):
            return self.code == other.code
        elif isinstance(other, str):
            return self.code == other
        else:
            return NotImplemented

    def __ne__(self, other):
        return not (self == other)

    def __hash__(self):
        # Hash by code so Currency('USD') and 'USD' collide in sets/dicts.
        return hash(self.code)

    def __repr__(self):
        return f'{self.__class__.__name__}({self.code!r})'

    def __unicode__(self):
        return self.code

View file

@ -0,0 +1,220 @@
import re
from ..utils import str_coercible
# Valid ltree label path: word-character labels separated by single dots.
path_matcher = re.compile(r'^[A-Za-z0-9_]+(\.[A-Za-z0-9_]+)*$')


@str_coercible
class Ltree:
    """
    Ltree class wraps a valid string label path. It provides various
    convenience properties and methods.

    ::

        from sqlalchemy_utils import Ltree

        Ltree('1.2.3').path  # '1.2.3'

    Ltree always validates the given path.

    ::

        Ltree(None)  # raises TypeError

        Ltree('..')  # raises ValueError

    Validator is also available as class method.

    ::

        Ltree.validate('1.2.3')
        Ltree.validate(None)  # raises TypeError

    Ltree supports equality operators.

    ::

        Ltree('Countries.Finland') == Ltree('Countries.Finland')
        Ltree('Countries.Germany') != Ltree('Countries.Finland')

    Ltree objects are hashable.

    ::

        assert hash(Ltree('Finland')) == hash('Finland')

    Ltree objects have length.

    ::

        assert len(Ltree('1.2')) == 2
        assert len(Ltree('some.one.some.where'))  # 4

    You can easily find subpath indexes.

    ::

        assert Ltree('1.2.3').index('2.3') == 1
        assert Ltree('1.2.3.4.5').index('3.4') == 2

    Ltree objects can be sliced.

    ::

        assert Ltree('1.2.3')[0:2] == Ltree('1.2')
        assert Ltree('1.2.3')[1:] == Ltree('2.3')

    Finding longest common ancestor.

    ::

        assert Ltree('1.2.3.4.5').lca('1.2.3', '1.2.3.4', '1.2.3') == '1.2'
        assert Ltree('1.2.3.4.5').lca('1.2', '1.2.3') == '1'

    Ltree objects can be concatenated.

    ::

        assert Ltree('1.2') + Ltree('1.2') == Ltree('1.2.1.2')
    """

    def __init__(self, path_or_ltree):
        # Accept either another Ltree (copy its path) or a raw string.
        if isinstance(path_or_ltree, Ltree):
            self.path = path_or_ltree.path
        elif isinstance(path_or_ltree, str):
            self.validate(path_or_ltree)
            self.path = path_or_ltree
        else:
            raise TypeError(
                "Ltree() argument must be a string or an Ltree, not '{}'"
                .format(
                    type(path_or_ltree).__name__
                )
            )

    @classmethod
    def validate(cls, path):
        """Raise ValueError if *path* is not a valid ltree label path."""
        if path_matcher.match(path) is None:
            raise ValueError(
                f"'{path}' is not a valid ltree path."
            )

    def __len__(self):
        # Length is the number of labels, not characters.
        return len(self.path.split('.'))

    def index(self, other):
        """Return the label index where subpath *other* first occurs."""
        subpath = Ltree(other).path.split('.')
        parts = self.path.split('.')
        for index, _ in enumerate(parts):
            if parts[index:len(subpath) + index] == subpath:
                return index
        raise ValueError('subpath not found')

    def descendant_of(self, other):
        """
        is left argument a descendant of right (or equal)?

        ::

            assert Ltree('1.2.3.4.5').descendant_of('1.2.3')
        """
        # Compare our prefix of the same length against the other path.
        subpath = self[:len(Ltree(other))]
        return subpath == other

    def ancestor_of(self, other):
        """
        is left argument an ancestor of right (or equal)?

        ::

            assert Ltree('1.2.3').ancestor_of('1.2.3.4.5')
        """
        subpath = Ltree(other)[:len(self)]
        return subpath == self

    def __getitem__(self, key):
        if isinstance(key, int):
            # Single label, still wrapped as an Ltree.
            return Ltree(self.path.split('.')[key])
        elif isinstance(key, slice):
            return Ltree('.'.join(self.path.split('.')[key]))
        raise TypeError(
            'Ltree indices must be integers, not {}'.format(
                key.__class__.__name__
            )
        )

    def lca(self, *others):
        """
        Lowest common ancestor, i.e., longest common prefix of paths

        ::

            assert Ltree('1.2.3.4.5').lca('1.2.3', '1.2.3.4', '1.2.3') == '1.2'
        """
        other_parts = [Ltree(other).path.split('.') for other in others]
        parts = self.path.split('.')
        for index, element in enumerate(parts):
            # Stop at the first position where any path diverges or would
            # be fully consumed; everything before it is the common prefix.
            if any(
                other[index] != element or
                len(other) <= index + 1 or
                len(parts) == index + 1
                for other in other_parts
            ):
                if index == 0:
                    # No common prefix at all.
                    return None
                return Ltree('.'.join(parts[0:index]))

    def __add__(self, other):
        return Ltree(self.path + '.' + Ltree(other).path)

    def __radd__(self, other):
        return Ltree(other) + self

    def __eq__(self, other):
        # Comparable both against other Ltree objects and bare strings.
        if isinstance(other, Ltree):
            return self.path == other.path
        elif isinstance(other, str):
            return self.path == other
        else:
            return NotImplemented

    def __hash__(self):
        return hash(self.path)

    def __ne__(self, other):
        return not (self == other)

    def __repr__(self):
        return f'{self.__class__.__name__}({self.path!r})'

    def __unicode__(self):
        return self.path

    def __contains__(self, label):
        return label in self.path.split('.')

    def __gt__(self, other):
        # Orderings compare the raw path strings lexicographically.
        return self.path > other.path

    def __lt__(self, other):
        return self.path < other.path

    def __ge__(self, other):
        return self.path >= other.path

    def __le__(self, other):
        return self.path <= other.path

View file

@ -0,0 +1,54 @@
from functools import total_ordering
from .. import i18n
from ..utils import str_coercible
@str_coercible
@total_ordering
class WeekDay:
    """
    Wraps a single weekday as a zero-based index (0-6). The locale-aware
    ``position`` property orders days relative to the locale's first week
    day, so sorting WeekDay objects follows the active locale.
    """

    # Number of days in a week; valid indexes are 0 .. NUM_WEEK_DAYS - 1.
    NUM_WEEK_DAYS = 7

    def __init__(self, index):
        """
        :param index: zero-based day index in range [0, 7).
        :raises ValueError: if *index* is out of range.
        """
        if not (0 <= index < self.NUM_WEEK_DAYS):
            # BUGFIX: previous message claimed the range was "between 0
            # and 7" although 7 itself is rejected.
            raise ValueError(
                "index must be between 0 and %d" % (self.NUM_WEEK_DAYS - 1)
            )
        self.index = index

    def __eq__(self, other):
        if isinstance(other, WeekDay):
            return self.index == other.index
        else:
            return NotImplemented

    def __hash__(self):
        return hash(self.index)

    def __lt__(self, other):
        # Ordering is locale-aware via position (total_ordering supplies
        # the remaining comparison operators).
        return self.position < other.position

    def __repr__(self):
        return f'{self.__class__.__name__}({self.index!r})'

    def __unicode__(self):
        return self.name

    def get_name(self, width='wide', context='format'):
        """Return the localized day name (babel width/context options)."""
        names = i18n.babel.dates.get_day_names(
            width,
            context,
            i18n.get_locale()
        )
        return names[self.index]

    @property
    def name(self):
        """Localized wide day name for the active locale."""
        return self.get_name()

    @property
    def position(self):
        """Index rotated so the locale's first week day comes first."""
        return (
            self.index -
            i18n.get_locale().first_week_day
        ) % self.NUM_WEEK_DAYS

View file

@ -0,0 +1,57 @@
from ..utils import str_coercible
from .weekday import WeekDay
@str_coercible
class WeekDays:
    """
    A set of :class:`WeekDay` objects, constructible from a 7-character bit
    string ('1' marks an included day, index 0 first), another WeekDays
    instance, or any iterable of WeekDay objects.
    """

    def __init__(self, bit_string_or_week_days):
        """
        :param bit_string_or_week_days: bit string of length
            ``WeekDay.NUM_WEEK_DAYS``, an existing WeekDays, or an iterable
            of WeekDay objects.
        :raises ValueError: if a bit string has the wrong length or contains
            characters other than '0'/'1'.
        """
        if isinstance(bit_string_or_week_days, str):
            self._days = set()

            if len(bit_string_or_week_days) != WeekDay.NUM_WEEK_DAYS:
                raise ValueError(
                    'Bit string must be {} characters long.'.format(
                        WeekDay.NUM_WEEK_DAYS
                    )
                )

            for index, bit in enumerate(bit_string_or_week_days):
                if bit not in '01':
                    raise ValueError(
                        'Bit string may only contain zeroes and ones.'
                    )
                if bit == '1':
                    self._days.add(WeekDay(index))
        elif isinstance(bit_string_or_week_days, WeekDays):
            # BUGFIX: copy the set instead of aliasing it -- sharing the
            # same set object let mutations on one instance silently leak
            # into its copy.
            self._days = set(bit_string_or_week_days._days)
        else:
            self._days = set(bit_string_or_week_days)

    def __eq__(self, other):
        # Comparable against other WeekDays and against bit strings.
        if isinstance(other, WeekDays):
            return self._days == other._days
        elif isinstance(other, str):
            return self.as_bit_string() == other
        else:
            return NotImplemented

    def __iter__(self):
        # Iterate in locale-aware order (WeekDay sorts by position).
        yield from sorted(self._days)

    def __contains__(self, value):
        return value in self._days

    def __repr__(self):
        return '{}({!r})'.format(
            self.__class__.__name__,
            self.as_bit_string()
        )

    def __unicode__(self):
        return ', '.join(str(day) for day in self)

    def as_bit_string(self):
        """Serialize back to the 7-character '0'/'1' representation."""
        return ''.join(
            '1' if WeekDay(index) in self._days else '0'
            for index in range(WeekDay.NUM_WEEK_DAYS)
        )

View file

@ -0,0 +1,84 @@
import sqlalchemy as sa
class ProxyDict:
    """
    Dict-like façade over a dynamic relationship collection: items are
    looked up (and created on demand) by the value of one mapped attribute
    of the child class. Lookups are memoized in ``self.cache``.
    """

    def __init__(self, parent, collection_name, mapping_attr):
        self.parent = parent
        self.collection_name = collection_name
        # Child class and the attribute used as the dictionary key.
        self.child_class = mapping_attr.class_
        self.key_name = mapping_attr.key
        # key -> child instance (None caches a confirmed miss).
        self.cache = {}

    @property
    def collection(self):
        # The underlying (dynamic) relationship collection on the parent.
        return getattr(self.parent, self.collection_name)

    def keys(self):
        descriptor = getattr(self.child_class, self.key_name)
        # Query only the key column of each child row.
        return [x[0] for x in self.collection.values(descriptor)]

    def __contains__(self, key):
        if key in self.cache:
            return self.cache[key] is not None
        return self.fetch(key) is not None

    def has_key(self, key):
        # Legacy dict API alias for __contains__.
        return self.__contains__(key)

    def fetch(self, key):
        """Query the child matching *key*, caching the result (or None)."""
        session = sa.orm.object_session(self.parent)
        # Only query when the parent is persistent; a transient parent has
        # no rows to match against.
        if session and sa.orm.util.has_identity(self.parent):
            obj = self.collection.filter_by(**{self.key_name: key}).first()
            self.cache[key] = obj
            return obj

    def create_new_instance(self, key):
        """Create, append and cache a new child keyed by *key*."""
        value = self.child_class(**{self.key_name: key})
        self.collection.append(value)
        self.cache[key] = value
        return value

    def __getitem__(self, key):
        # Cached hit wins; a cached miss (None) falls through to creation.
        if key in self.cache:
            if self.cache[key] is not None:
                return self.cache[key]
        else:
            value = self.fetch(key)
            if value:
                return value
        # Missing keys auto-create a child instead of raising KeyError.
        return self.create_new_instance(key)

    def __setitem__(self, key, value):
        try:
            existing = self[key]
            self.collection.remove(existing)
        except KeyError:
            pass
        self.collection.append(value)
        self.cache[key] = value
def proxy_dict(parent, collection_name, mapping_attr):
    """Return a per-collection ProxyDict for *parent*, creating and caching
    it on the parent (``_proxy_dicts``) on first use."""
    if not hasattr(parent, '_proxy_dicts'):
        parent._proxy_dicts = {}
    registry = parent._proxy_dicts
    if collection_name not in registry:
        registry[collection_name] = ProxyDict(
            parent,
            collection_name,
            mapping_attr
        )
    return registry[collection_name]
def expire_proxy_dicts(target, context):
    """Drop all cached ProxyDicts on *target* (mapper 'expire' hook)."""
    try:
        target._proxy_dicts
    except AttributeError:
        return
    target._proxy_dicts = {}
sa.event.listen(sa.orm.Mapper, 'expire', expire_proxy_dicts)

View file

@ -0,0 +1,173 @@
"""
QueryChain is a wrapper for sequence of queries.
Features:
* Easy iteration for sequence of queries
* Limit, offset and count which are applied to all queries in the chain
* Smart __getitem__ support
Initialization
^^^^^^^^^^^^^^
QueryChain takes iterable of queries as first argument. Additionally limit and
offset parameters can be given
::
chain = QueryChain([session.query(User), session.query(Article)])
chain = QueryChain(
[session.query(User), session.query(Article)],
limit=4
)
Simple iteration
^^^^^^^^^^^^^^^^
::
chain = QueryChain([session.query(User), session.query(Article)])
for obj in chain:
print(obj)
Limit and offset
^^^^^^^^^^^^^^^^
Lets say you have 5 blog posts, 5 articles and 5 news items in your
database.
::
chain = QueryChain(
[
session.query(BlogPost),
session.query(Article),
session.query(NewsItem)
],
limit=5
)
list(chain) # all blog posts but not articles and news items
chain = chain.offset(4)
list(chain) # last blog post, and first four articles
Just like with original query object the limit and offset can be chained to
return a new QueryChain.
::
chain = chain.limit(5).offset(7)
Chain slicing
^^^^^^^^^^^^^
::
chain = QueryChain(
[
session.query(BlogPost),
session.query(Article),
session.query(NewsItem)
]
)
chain[3:6] # New QueryChain with offset=3 and limit=6
Count
^^^^^
Let's assume that there are five blog posts, five articles and five news
items in the database, and you have the following query chain::
chain = QueryChain(
[
session.query(BlogPost),
session.query(Article),
session.query(NewsItem)
]
)
You can then get the total number rows returned by the query chain
with :meth:`~QueryChain.count`::
>>> chain.count()
15
"""
from copy import copy
class QueryChain:
    """
    QueryChain can be used as a wrapper for sequence of queries.

    :param queries: A sequence of SQLAlchemy Query objects
    :param limit: Similar to normal query limit this parameter can be used for
        limiting the number of results for the whole query chain.
    :param offset: Similar to normal query offset this parameter can be used
        for offsetting the query chain as a whole.

    .. versionadded:: 0.26.0
    """

    def __init__(self, queries, limit=None, offset=None):
        self.queries = queries
        self._limit = limit
        self._offset = offset

    def __iter__(self):
        # consumed: rows already yielded (counts against the chain limit).
        # skipped: rows already passed over (counts against the offset).
        consumed = 0
        skipped = 0
        for query in self.queries:
            # Keep an unmodified copy so its full count can be taken when
            # the (offset) query yields nothing.
            query_copy = copy(query)
            if self._limit:
                query = query.limit(self._limit - consumed)
            if self._offset:
                query = query.offset(self._offset - skipped)

            obj_count = 0
            for obj in query:
                consumed += 1
                obj_count += 1
                yield obj

            if not obj_count:
                # Whole query fell inside the remaining offset; charge its
                # full row count against the offset.
                skipped += query_copy.count()
            else:
                skipped += obj_count

    def limit(self, value):
        # Returns a new chain; does not mutate self.
        return self[:value]

    def offset(self, value):
        return self[value:]

    def count(self):
        """
        Return the total number of rows this QueryChain's queries would return.
        """
        return sum(q.count() for q in self.queries)

    def __getitem__(self, key):
        if isinstance(key, slice):
            # Slice maps to a new chain: start -> offset, stop -> limit.
            return self.__class__(
                queries=self.queries,
                limit=key.stop if key.stop is not None else self._limit,
                offset=key.start if key.start is not None else self._offset
            )
        else:
            # Integer index: first object of a 1-row chain at that offset.
            for obj in self[key:1]:
                return obj

    def __repr__(self):
        return '<QueryChain at 0x%x>' % id(self)

View file

@ -0,0 +1,128 @@
import sqlalchemy as sa
import sqlalchemy.orm
from sqlalchemy.sql.util import ClauseAdapter
from ..compat import _select_args
from .chained_join import chained_join # noqa
def path_to_relationships(path, cls):
    """Resolve dot-path *path* starting at mapped class *cls* into the list
    of relationship attributes it traverses, in order."""
    relationships = []
    current = cls
    for segment in path.split('.'):
        descriptor = getattr(current, segment)
        relationships.append(descriptor)
        # Hop to the class the relationship points at.
        current = descriptor.mapper.class_
    return relationships
def adapt_expr(expr, *selectables):
    """Rewrite *expr* against each given selectable in turn using
    ClauseAdapter, returning the fully adapted expression."""
    adapted = expr
    for selectable in selectables:
        adapted = ClauseAdapter(selectable).traverse(adapted)
    return adapted
def inverse_join(selectable, left_alias, right_alias, relationship):
    """Join *selectable* to *right_alias* by walking *relationship*
    backwards (from its target side toward its parent side)."""
    if relationship.property.secondary is not None:
        # Many-to-many: hop through an alias of the association table,
        # adapting the secondary/primary join conditions to the aliases.
        secondary_alias = sa.alias(relationship.property.secondary)
        return selectable.join(
            secondary_alias,
            adapt_expr(
                relationship.property.secondaryjoin,
                sa.inspect(left_alias).selectable,
                secondary_alias
            )
        ).join(
            right_alias,
            adapt_expr(
                relationship.property.primaryjoin,
                sa.inspect(right_alias).selectable,
                secondary_alias
            )
        )
    else:
        # Plain FK relationship: let the ORM build the reversed ON clause.
        join = sa.orm.join(right_alias, left_alias, relationship)
        onclause = join.onclause
        return selectable.join(right_alias, onclause)
def relationship_to_correlation(relationship, alias):
    """Return the join condition linking *relationship*'s parent to *alias*,
    suitable for use as a correlated WHERE clause."""
    if relationship.property.secondary is not None:
        # Many-to-many: correlate through the association table side.
        return adapt_expr(
            relationship.property.primaryjoin,
            alias,
        )
    else:
        return sa.orm.join(
            relationship.parent,
            alias,
            relationship
        ).onclause
def chained_inverse_join(relationships, leaf_model):
    """Build a join from *leaf_model* back along *relationships* (given in
    reversed path order), returning (selectable, aliases used).

    ``aliases[0]`` is the leaf model itself; subsequent entries are aliased
    classes for each hop (plus a trailing association-table alias when the
    final hop is many-to-many).
    """
    selectable = sa.inspect(leaf_model).selectable
    aliases = [leaf_model]
    for index, relationship in enumerate(relationships[1:]):
        aliases.append(sa.orm.aliased(relationship.mapper.class_))
        # NOTE(review): joins use relationships[index] (offset by one from
        # the enumerated slice) -- intentional pairing of hop i with the
        # alias created for hop i+1; confirm against upstream tests.
        selectable = inverse_join(
            selectable,
            aliases[index],
            aliases[index + 1],
            relationships[index]
        )
    if relationships[-1].property.secondary is not None:
        # Final many-to-many hop: join its association table so the caller
        # can correlate on it.
        secondary_alias = sa.alias(relationships[-1].property.secondary)
        selectable = selectable.join(
            secondary_alias,
            adapt_expr(
                relationships[-1].property.secondaryjoin,
                secondary_alias,
                sa.inspect(aliases[-1]).selectable
            )
        )
        aliases.append(secondary_alias)
    return selectable, aliases
def select_correlated_expression(
    root_model,
    expr,
    path,
    leaf_model,
    from_obj=None,
    order_by=None,
    correlate=True
):
    """Build a SELECT of *expr* over the objects reached from *root_model*
    via dot-path *path*, correlated back to the root (or *from_obj*).

    :param root_model: mapped class the path starts from
    :param expr: expression to select
    :param path: dot-notated relationship path from root to leaf
    :param leaf_model: (aliased) class at the end of the path
    :param from_obj: optional selectable to correlate against instead of
        the root model
    :param order_by: optional iterable of ORDER BY expressions, adapted to
        the join aliases
    :param correlate: whether to correlate the subquery at all
    """
    # Path is walked from the leaf back toward the root.
    relationships = list(reversed(path_to_relationships(path, root_model)))

    query = sa.select(*_select_args(expr))
    join_expr, aliases = chained_inverse_join(relationships, leaf_model)

    if order_by:
        # ORDER BY expressions must reference the aliases of the join.
        query = query.order_by(
            *[
                adapt_expr(
                    o,
                    *(sa.inspect(alias).selectable for alias in aliases)
                )
                for o in order_by
            ]
        )

    # Correlation condition tying the innermost hop back to the outer query.
    condition = relationship_to_correlation(
        relationships[-1],
        aliases[-1]
    )
    if from_obj is not None:
        condition = adapt_expr(condition, from_obj)
    query = query.select_from(join_expr.selectable)
    if correlate:
        query = query.correlate(
            from_obj if from_obj is not None else root_model
        )
    return query.where(condition)

View file

@ -0,0 +1,31 @@
def chained_join(*relationships):
    """
    Return a chained Join object for given relationships.
    """
    property_ = relationships[0].property
    if property_.secondary is not None:
        # First hop is many-to-many: start from the association table
        # joined to the target table.
        from_ = property_.secondary.join(
            property_.mapper.class_.__table__,
            property_.secondaryjoin
        )
    else:
        from_ = property_.mapper.class_.__table__

    for relationship in relationships[1:]:
        prop = relationship.property
        if prop.secondary is not None:
            # Many-to-many hop: join association table, then target class.
            from_ = from_.join(
                prop.secondary,
                prop.primaryjoin
            )
            from_ = from_.join(
                prop.mapper.class_,
                prop.secondaryjoin
            )
        else:
            from_ = from_.join(
                prop.mapper.class_,
                prop.primaryjoin
            )
    return from_

View file

@ -0,0 +1,63 @@
from functools import wraps
from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList
from .arrow import ArrowType # noqa
from .choice import Choice, ChoiceType # noqa
from .color import ColorType # noqa
from .country import CountryType # noqa
from .currency import CurrencyType # noqa
from .email import EmailType # noqa
from .encrypted.encrypted_type import ( # noqa
EncryptedType,
StringEncryptedType
)
from .enriched_datetime.enriched_date_type import EnrichedDateType # noqa
from .ip_address import IPAddressType # noqa
from .json import JSONType # noqa
from .locale import LocaleType # noqa
from .ltree import LtreeType # noqa
from .password import Password, PasswordType # noqa
from .pg_composite import ( # noqa
CompositeType,
register_composites,
remove_composite_listeners
)
from .phone_number import ( # noqa
PhoneNumber,
PhoneNumberParseException,
PhoneNumberType
)
from .range import ( # noqa
DateRangeType,
DateTimeRangeType,
Int8RangeType,
IntRangeType,
NumericRangeType
)
from .scalar_list import ScalarListException, ScalarListType # noqa
from .timezone import TimezoneType # noqa
from .ts_vector import TSVectorType # noqa
from .url import URLType # noqa
from .uuid import UUIDType # noqa
from .weekdays import WeekDaysType # noqa
from .enriched_datetime.enriched_datetime_type import EnrichedDateTimeType # noqa isort:skip
class InstrumentedList(_InstrumentedList):
    """Enhanced version of SQLAlchemy InstrumentedList. Provides some
    additional functionality."""

    def any(self, attr):
        """Return True if ``item.attr`` is truthy for at least one item."""
        for item in self:
            if getattr(item, attr):
                return True
        return False

    def all(self, attr):
        """Return True if ``item.attr`` is truthy for every item."""
        for item in self:
            if not getattr(item, attr):
                return False
        return True
def instrumented_list(f):
    """Decorator converting the iterable returned by *f* into an
    InstrumentedList."""
    @wraps(f)
    def wrapper(*args, **kwargs):
        # list() consumes the iterable exactly like the original
        # comprehension did.
        return InstrumentedList(f(*args, **kwargs))
    return wrapper

View file

@ -0,0 +1,62 @@
from ..exceptions import ImproperlyConfigured
from .enriched_datetime import ArrowDateTime
from .enriched_datetime.enriched_datetime_type import EnrichedDateTimeType
arrow = None
try:
import arrow
except ImportError:
pass
class ArrowType(EnrichedDateTimeType):
    """
    ArrowType provides way of saving Arrow_ objects into database. It
    automatically changes Arrow_ objects to datetime objects on the way in and
    datetime objects back to Arrow_ objects on the way out (when querying
    database). ArrowType needs Arrow_ library installed.

    .. _Arrow: https://github.com/arrow-py/arrow

    ::

        from datetime import datetime
        from sqlalchemy_utils import ArrowType
        import arrow


        class Article(Base):
            __tablename__ = 'article'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))
            created_at = sa.Column(ArrowType)

        article = Article(created_at=arrow.utcnow())

    As you may expect all the arrow goodies come available:

    ::

        article.created_at = article.created_at.shift(hours=-1)

        article.created_at.humanize()
        # 'an hour ago'
    """
    # Safe to cache: behavior depends only on the type's configuration.
    cache_ok = True

    def __init__(self, *args, **kwargs):
        # arrow is an optional dependency; fail loudly rather than at first
        # value conversion.
        if not arrow:
            raise ImproperlyConfigured(
                "'arrow' package is required to use 'ArrowType'"
            )

        super().__init__(
            datetime_processor=ArrowDateTime,
            *args,
            **kwargs
        )

View file

@ -0,0 +1,24 @@
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import BIT
class BitType(sa.types.TypeDecorator):
    """
    BitType offers way of saving BITs into database.
    """
    # Generic fallback implementation; PostgreSQL/SQLite get specialized
    # types in load_dialect_impl.
    impl = sa.types.BINARY
    cache_ok = True

    def __init__(self, length=1, **kwargs):
        # Number of bits stored.
        self.length = length
        sa.types.TypeDecorator.__init__(self, **kwargs)

    def load_dialect_impl(self, dialect):
        # Use the native BIT type for drivers that has it.
        if dialect.name == 'postgresql':
            return dialect.type_descriptor(BIT(self.length))
        elif dialect.name == 'sqlite':
            # SQLite has no BIT type; store as a string of the same length.
            return dialect.type_descriptor(sa.String(self.length))
        else:
            return dialect.type_descriptor(type(self.impl)(self.length))

View file

@ -0,0 +1,225 @@
from enum import Enum
from sqlalchemy import types
from ..exceptions import ImproperlyConfigured
from .scalar_coercible import ScalarCoercible
class Choice:
    """A ``(code, value)`` pair produced by :class:`ChoiceType` columns.

    Compares equal both to other ``Choice`` objects carrying the same code
    and to the bare code itself; hashes by code so it can sit in sets and
    dict keys alongside raw codes.
    """

    def __init__(self, code, value):
        self.code = code
        self.value = value

    def __eq__(self, other):
        # A Choice equals another Choice with the same code, and also the
        # bare code ('admin' == Choice('admin', 'Admin')).
        same_type = isinstance(other, Choice)
        return self.code == other.code if same_type else other == self.code

    def __hash__(self):
        return hash(self.code)

    def __ne__(self, other):
        is_equal = self == other
        return not is_equal

    def __str__(self):
        # The display value, not the code.
        return str(self.value)

    def __repr__(self):
        return f'Choice(code={self.code}, value={self.value}'
class ChoiceType(ScalarCoercible, types.TypeDecorator):
    """Column type constrained to a fixed set of choices.

    ``choices`` may be either a collection of ``(code, value)`` pairs, in
    which case values read from the database are coerced to
    :class:`Choice` objects, or a subclass of :class:`enum.Enum`, in which
    case values are coerced to enum members.

    ::

        class User(Base):
            TYPES = [
                ('admin', 'Admin'),
                ('regular-user', 'Regular user')
            ]

            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))
            type = sa.Column(ChoiceType(TYPES))


        user = User(type='admin')
        user.type  # Choice(code='admin', value='Admin')

    Or with an :class:`enum.Enum`::

        import enum


        class UserType(enum.Enum):
            admin = 1
            regular = 2


        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))
            type = sa.Column(ChoiceType(UserType, impl=sa.Integer()))


        user = User(type=1)
        user.type  # <UserType.admin: 1>

    :param choices: list of ``(code, value)`` tuples or an ``Enum`` subclass
    :param impl: optional underlying column type
        (defaults to ``Unicode(255)``)
    """
    impl = types.Unicode(255)
    cache_ok = True

    def __init__(self, choices, impl=None):
        # Tuples are hashable, which keeps the type cacheable (cache_ok).
        self.choices = tuple(choices) if isinstance(choices, list) else choices
        # `Enum` comes unconditionally from the stdlib `enum` module, so the
        # legacy `Enum is not None` guard was dead code and has been removed.
        if isinstance(choices, type) and issubclass(choices, Enum):
            self.type_impl = EnumTypeImpl(enum_class=choices)
        else:
            self.type_impl = ChoiceTypeImpl(choices=choices)
        if impl:
            self.impl = impl

    @property
    def python_type(self):
        return self.impl.python_type

    def _coerce(self, value):
        # Delegate to the pair- or enum-based implementation chosen above.
        return self.type_impl._coerce(value)

    def process_bind_param(self, value, dialect):
        """Convert a Choice/enum member to its stored code on the way in."""
        return self.type_impl.process_bind_param(value, dialect)

    def process_result_value(self, value, dialect):
        """Convert a stored code back to a Choice/enum member on the way
        out."""
        return self.type_impl.process_result_value(value, dialect)
class ChoiceTypeImpl:
    """Backs :class:`ChoiceType` when choices are ``(code, value)`` pairs."""

    def __init__(self, choices):
        if not choices:
            raise ImproperlyConfigured(
                'ChoiceType needs list of choices defined.'
            )
        # Map code -> display value for O(1) lookups during coercion.
        self.choices_dict = dict(choices)

    def _coerce(self, value):
        # None and already-coerced Choice objects pass through untouched;
        # raw codes are wrapped (unknown codes raise KeyError).
        if value is None or isinstance(value, Choice):
            return value
        return Choice(value, self.choices_dict[value])

    def process_bind_param(self, value, dialect):
        # Store only the code; non-Choice values go through unchanged.
        is_choice = isinstance(value, Choice)
        return value.code if value and is_choice else value

    def process_result_value(self, value, dialect):
        # Truthy stored codes are re-wrapped as Choice objects.
        return Choice(value, self.choices_dict[value]) if value else value
class EnumTypeImpl:
    """Backs :class:`ChoiceType` when choices are an :class:`enum.Enum`.

    :param enum_class: the ``Enum`` subclass whose members are the valid
        choices.
    :raises ImproperlyConfigured: if ``enum_class`` is not an ``Enum``
        subclass.
    """

    def __init__(self, enum_class):
        # `enum` has been in the standard library since Python 3.4, so the
        # old "'enum34' package is required" guard (`if Enum is None`) was
        # unreachable and has been dropped.
        if not issubclass(enum_class, Enum):
            raise ImproperlyConfigured(
                "EnumType needs a class of enum defined."
            )
        self.enum_class = enum_class

    def _coerce(self, value):
        """Coerce a raw value (or an existing member) to an enum member."""
        if value is None:
            return None
        return self.enum_class(value)

    def process_bind_param(self, value, dialect):
        # Round-trip through the enum so invalid raw values raise
        # ValueError early; store the member's .value.
        if value is None:
            return None
        return self.enum_class(value).value

    def process_result_value(self, value, dialect):
        """Turn the stored value back into an enum member."""
        return self._coerce(value)

View file

@ -0,0 +1,79 @@
from sqlalchemy import types
from ..exceptions import ImproperlyConfigured
from .scalar_coercible import ScalarCoercible
colour = None
try:
import colour
python_colour_type = colour.Color
except (ImportError, AttributeError):
python_colour_type = None
class ColorType(ScalarCoercible, types.TypeDecorator):
    """Persists ``colour.Color`` objects as short strings.

    Colors are written in the attribute named by ``STORE_FORMAT`` (``hex``)
    and re-hydrated to ``colour.Color`` instances when read back.  Requires
    the colour_ package.

    .. _colour: https://github.com/vaab/colour

    ::

        from colour import Color
        from sqlalchemy_utils import ColorType


        class Document(Base):
            __tablename__ = 'document'
            id = sa.Column(sa.Integer, autoincrement=True)
            name = sa.Column(sa.Unicode(50))
            background_color = sa.Column(ColorType)


        document = Document(background_color=Color('#F5F5F5'))
        session.commit()
    """
    STORE_FORMAT = 'hex'
    impl = types.Unicode(20)
    python_type = python_colour_type
    cache_ok = True

    def __init__(self, max_length=20, *args, **kwargs):
        # The optional dependency must be importable before this type is
        # usable at all.
        if colour is None:
            raise ImproperlyConfigured(
                "'colour' package is required to use 'ColorType'"
            )
        super().__init__(*args, **kwargs)
        # Re-size the storage column to the requested width.
        self.impl = types.Unicode(max_length)

    def process_bind_param(self, value, dialect):
        # Serialize Color objects via the configured attribute ('hex');
        # anything else (including None) is passed through.
        if isinstance(value, colour.Color) and value:
            return str(getattr(value, self.STORE_FORMAT))
        return value

    def process_result_value(self, value, dialect):
        return colour.Color(value) if value else value

    def _coerce(self, value):
        needs_wrap = value is not None and not isinstance(value, colour.Color)
        return colour.Color(value) if needs_wrap else value

View file

@ -0,0 +1,64 @@
from sqlalchemy import types
from ..primitives import Country
from .scalar_coercible import ScalarCoercible
class CountryType(ScalarCoercible, types.TypeDecorator):
    """Persists :class:`.Country` objects as two-letter country codes.

    Requires Babel_ (used by ``Country`` for localized names).

    .. _Babel: https://babel.pocoo.org/

    ::

        from sqlalchemy_utils import CountryType, Country


        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, autoincrement=True)
            name = sa.Column(sa.Unicode(255))
            country = sa.Column(CountryType)


        user = User(country=Country('FI'))
        session.add(user)
        session.commit()

        user.country       # Country('FI')
        user.country.name  # Finland

    The type is scalar coercible, so bare code strings are accepted too::

        user.country = 'US'
        user.country  # Country('US')
    """
    impl = types.String(2)
    python_type = Country
    cache_ok = True

    def process_bind_param(self, value, dialect):
        # Accept a Country or a raw code string; anything else (including
        # None) maps to NULL.
        if isinstance(value, Country):
            return value.code
        return value if isinstance(value, str) else None

    def process_result_value(self, value, dialect):
        return Country(value) if value is not None else None

    def _coerce(self, value):
        if value is None or isinstance(value, Country):
            return value
        return Country(value)

View file

@ -0,0 +1,74 @@
from sqlalchemy import types
from .. import i18n, ImproperlyConfigured
from ..primitives import Currency
from .scalar_coercible import ScalarCoercible
class CurrencyType(ScalarCoercible, types.TypeDecorator):
    """Persists :class:`.Currency` objects as three-letter currency codes.

    Requires Babel_ (used by ``Currency`` for names and symbols).

    .. _Babel: https://babel.pocoo.org/

    ::

        from sqlalchemy_utils import CurrencyType, Currency


        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, autoincrement=True)
            name = sa.Column(sa.Unicode(255))
            currency = sa.Column(CurrencyType)


        user = User(currency=Currency('USD'))
        session.add(user)
        session.commit()

        user.currency         # Currency('USD')
        user.currency.name    # US Dollar
        user.currency.symbol  # $

    The type is scalar coercible, so bare code strings are accepted too::

        user.currency = 'US'
        user.currency  # Currency('US')
    """
    impl = types.String(3)
    python_type = Currency
    cache_ok = True

    def __init__(self, *args, **kwargs):
        # Currency itself relies on babel, so refuse to construct without it.
        if i18n.babel is None:
            raise ImproperlyConfigured(
                "'babel' package is required in order to use CurrencyType."
            )
        super().__init__(*args, **kwargs)

    def process_bind_param(self, value, dialect):
        # Accept a Currency or a raw code string; anything else (including
        # None) maps to NULL.
        if isinstance(value, Currency):
            return value.code
        return value if isinstance(value, str) else None

    def process_result_value(self, value, dialect):
        return Currency(value) if value is not None else None

    def _coerce(self, value):
        if value is None or isinstance(value, Currency):
            return value
        return Currency(value)

View file

@ -0,0 +1,48 @@
import sqlalchemy as sa
from ..operators import CaseInsensitiveComparator
class EmailType(sa.types.TypeDecorator):
    """Stores e-mail addresses lower-cased.

    Combined with :class:`CaseInsensitiveComparator`, lookups become
    case-insensitive regardless of how the address was written::

        from sqlalchemy_utils import EmailType


        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))
            email = sa.Column(EmailType)


        user = User(name='John Smith', email='John.Smith@foo.com')
        session.add(user)
        session.commit()

        # Note: the filter value may be any casing.
        user = (session.query(User)
                .filter(User.email == 'john.smith@foo.com')
                .one())
        assert user.name == 'John Smith'
    """
    impl = sa.Unicode
    comparator_factory = CaseInsensitiveComparator
    cache_ok = True

    def __init__(self, length=255, *args, **kwargs):
        super().__init__(length=length, *args, **kwargs)

    def process_bind_param(self, value, dialect):
        # Normalize on the way in so stored values compare consistently.
        return value if value is None else value.lower()

    @property
    def python_type(self):
        return self.impl.type.python_type

View file

@ -0,0 +1 @@
# Module for encrypted type

View file

@ -0,0 +1,508 @@
import base64
import datetime
import json
import os
import warnings
from sqlalchemy.types import LargeBinary, String, TypeDecorator
from sqlalchemy_utils.exceptions import ImproperlyConfigured
from sqlalchemy_utils.types.encrypted.padding import PADDING_MECHANISM
from sqlalchemy_utils.types.json import JSONType
from sqlalchemy_utils.types.scalar_coercible import ScalarCoercible
cryptography = None
try:
import cryptography
from cryptography.exceptions import InvalidTag
from cryptography.fernet import Fernet
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.ciphers import (
algorithms,
Cipher,
modes
)
except ImportError:
pass
dateutil = None
try:
import dateutil
from dateutil.parser import parse as datetime_parse
except ImportError:
pass
class InvalidCiphertextError(Exception):
    """Raised when stored ciphertext cannot be decrypted: payload too short,
    failed GCM authentication tag, or undecodable plaintext."""
    pass
class EncryptionDecryptionBaseEngine:
    """A base encryption and decryption engine.

    This class must be sub-classed in order to create new engines;
    subclasses provide ``_initialize_engine``, ``encrypt`` and ``decrypt``.
    """
    def _update_key(self, key):
        # Derive a fixed-size engine key by hashing the user-supplied key
        # with SHA-256, then hand it to the subclass's _initialize_engine.
        if isinstance(key, str):
            key = key.encode()
        digest = hashes.Hash(hashes.SHA256(), backend=default_backend())
        digest.update(key)
        engine_key = digest.finalize()
        self._initialize_engine(engine_key)

    def encrypt(self, value):
        raise NotImplementedError('Subclasses must implement this!')

    def decrypt(self, value):
        raise NotImplementedError('Subclasses must implement this!')
class AesEngine(EncryptionDecryptionBaseEngine):
    """Provide AES (CBC) encryption and decryption methods.

    The IV is derived from the key (see ``_initialize_engine``), so
    encryption is deterministic: equal plaintexts yield equal ciphertexts.
    That makes searching rows by the value of an encrypted column possible,
    at the cost of leaking value equality.  If you do not need such
    searches, prefer ``AesGcmEngine``, which provides better security.
    """
    BLOCK_SIZE = 16

    def _initialize_engine(self, parent_class_key):
        self.secret_key = parent_class_key
        # NOTE: the IV is the first 16 bytes of the hashed key, not random;
        # this is exactly what makes the engine deterministic/searchable.
        self.iv = self.secret_key[:16]
        self.cipher = Cipher(
            algorithms.AES(self.secret_key),
            modes.CBC(self.iv),
            backend=default_backend()
        )

    def _set_padding_mechanism(self, padding_mechanism=None):
        """Set the padding mechanism."""
        # Unknown names raise; None falls back to the legacy 'naive'
        # padding for backwards compatibility.
        if isinstance(padding_mechanism, str):
            if padding_mechanism not in PADDING_MECHANISM.keys():
                raise ImproperlyConfigured(
                    "There is not padding mechanism with name {}".format(
                        padding_mechanism
                    )
                )
        if padding_mechanism is None:
            padding_mechanism = 'naive'
        padding_class = PADDING_MECHANISM[padding_mechanism]
        self.padding_engine = padding_class(self.BLOCK_SIZE)

    def encrypt(self, value):
        # Non-string values are serialized with repr() before encryption.
        if not isinstance(value, str):
            value = repr(value)
        if isinstance(value, str):
            value = str(value)
        value = value.encode()
        value = self.padding_engine.pad(value)
        encryptor = self.cipher.encryptor()
        encrypted = encryptor.update(value) + encryptor.finalize()
        # Ciphertext is stored base64-encoded so it fits in text columns.
        encrypted = base64.b64encode(encrypted)
        return encrypted.decode('utf-8')

    def decrypt(self, value):
        if isinstance(value, str):
            value = str(value)
        decryptor = self.cipher.decryptor()
        decrypted = base64.b64decode(value)
        decrypted = decryptor.update(decrypted) + decryptor.finalize()
        decrypted = self.padding_engine.unpad(decrypted)
        if not isinstance(decrypted, str):
            try:
                decrypted = decrypted.decode('utf-8')
            except UnicodeDecodeError:
                # A wrong key yields garbage bytes that rarely form UTF-8.
                raise ValueError('Invalid decryption key')
        return decrypted
class AesGcmEngine(EncryptionDecryptionBaseEngine):
    """Provide AES/GCM encryption and decryption methods.

    A fresh random IV is generated for every encryption, so equal
    plaintexts yield different ciphertexts and the authentication tag
    detects tampering.  Because of that you can NOT search rows by the
    value of an encrypted column with this engine -- use ``AesEngine`` if
    you need such searches.  If you don't, the AesGcmEngine provides
    better security.
    """
    BLOCK_SIZE = 16
    IV_BYTES_NEEDED = 12
    TAG_SIZE_BYTES = BLOCK_SIZE

    def _initialize_engine(self, parent_class_key):
        self.secret_key = parent_class_key

    def encrypt(self, value):
        # Non-string values are serialized with repr() before encryption.
        if not isinstance(value, str):
            value = repr(value)
        if isinstance(value, str):
            value = str(value)
        value = value.encode()
        # Fresh random 96-bit IV per message.
        iv = os.urandom(self.IV_BYTES_NEEDED)
        cipher = Cipher(
            algorithms.AES(self.secret_key),
            modes.GCM(iv),
            backend=default_backend()
        )
        encryptor = cipher.encryptor()
        encrypted = encryptor.update(value) + encryptor.finalize()
        assert len(encryptor.tag) == self.TAG_SIZE_BYTES
        # Stored layout: base64(iv || tag || ciphertext).
        encrypted = base64.b64encode(iv + encryptor.tag + encrypted)
        return encrypted.decode('utf-8')

    def decrypt(self, value):
        if isinstance(value, str):
            value = str(value)
        decrypted = base64.b64decode(value)
        # Reject payloads too short to even contain the IV and tag.
        if len(decrypted) < self.IV_BYTES_NEEDED + self.TAG_SIZE_BYTES:
            raise InvalidCiphertextError()
        # Split the iv || tag || ciphertext layout produced by encrypt().
        iv = decrypted[:self.IV_BYTES_NEEDED]
        tag = decrypted[self.IV_BYTES_NEEDED:
                        self.IV_BYTES_NEEDED + self.TAG_SIZE_BYTES]
        decrypted = decrypted[self.IV_BYTES_NEEDED + self.TAG_SIZE_BYTES:]
        cipher = Cipher(
            algorithms.AES(self.secret_key),
            modes.GCM(iv, tag),
            backend=default_backend()
        )
        decryptor = cipher.decryptor()
        try:
            decrypted = decryptor.update(decrypted) + decryptor.finalize()
        except InvalidTag:
            # Tampered or wrong-key ciphertext fails authentication.
            raise InvalidCiphertextError()
        if not isinstance(decrypted, str):
            try:
                decrypted = decrypted.decode('utf-8')
            except UnicodeDecodeError:
                raise InvalidCiphertextError()
        return decrypted
class FernetEngine(EncryptionDecryptionBaseEngine):
    """Provide Fernet encryption and decryption methods."""
    def _initialize_engine(self, parent_class_key):
        # Fernet expects a url-safe base64-encoded key.
        self.secret_key = base64.urlsafe_b64encode(parent_class_key)
        self.fernet = Fernet(self.secret_key)

    def encrypt(self, value):
        # Non-string values are serialized with repr() before encryption.
        if not isinstance(value, str):
            value = repr(value)
        if isinstance(value, str):
            value = str(value)
        value = value.encode()
        encrypted = self.fernet.encrypt(value)
        return encrypted.decode('utf-8')

    def decrypt(self, value):
        if isinstance(value, str):
            value = str(value)
        decrypted = self.fernet.decrypt(value.encode())
        if not isinstance(decrypted, str):
            decrypted = decrypted.decode('utf-8')
        return decrypted
class StringEncryptedType(TypeDecorator, ScalarCoercible):
    """Encrypts values transparently, storing the ciphertext as a string.

    The underlying value may be of (almost) any basic SQLAlchemy type --
    e.g. ``Unicode``, ``String`` or even ``Boolean``.  Values are encrypted
    on the way into the database and decrypted on the way out.  Requires
    the Cryptography_ package.

    .. _Cryptography: https://cryptography.io/en/latest/

    Be explicit about the underlying type, engine and padding::

        a_column = sa.Column(EncryptedType(sa.Unicode,
                                           secret_key,
                                           FernetEngine))

        another_column = sa.Column(EncryptedType(sa.Unicode,
                                                 secret_key,
                                                 AesEngine,
                                                 'pkcs5'))

    The ``key`` argument also accepts a zero-argument callable, which is
    re-evaluated on every bind/result so the key can vary at runtime
    (e.g. per row)::

        def get_key():
            return 'dynamic-key'

        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            username = sa.Column(EncryptedType(
                sa.Unicode, get_key))

    :param type_in: underlying SQLAlchemy type (class or instance);
        defaults to ``String()``.
    :param key: encryption key, or a zero-argument callable returning it.
    :param engine: encryption engine class; defaults to ``AesEngine``.
    :param padding: padding mechanism name, used only by ``AesEngine``.
    """
    impl = String
    cache_ok = True

    def __init__(self, type_in=None, key=None, engine=None, padding=None,
                 **kwargs):
        """Initialization."""
        if not cryptography:
            raise ImproperlyConfigured(
                "'cryptography' is required to use EncryptedType"
            )
        super().__init__(**kwargs)
        # Normalize the underlying type: default to String() and
        # instantiate it when a bare class was passed.
        if type_in is None:
            type_in = String()
        elif isinstance(type_in, type):
            type_in = type_in()
        self.underlying_type = type_in
        self._key = key
        engine = engine or AesEngine
        self.engine = engine()
        # Only the AES/CBC engine understands padding mechanisms.
        if isinstance(self.engine, AesEngine):
            self.engine._set_padding_mechanism(padding)

    @property
    def key(self):
        return self._key

    @key.setter
    def key(self, value):
        self._key = value

    def _update_key(self):
        # Resolve callable keys on every use so they may change per call.
        key = self._key() if callable(self._key) else self._key
        self.engine._update_key(key)

    def process_bind_param(self, value, dialect):
        """Encrypt a value on the way in."""
        if value is None:
            return None
        self._update_key()
        try:
            value = self.underlying_type.process_bind_param(
                value, dialect
            )
        except AttributeError:
            # The wrapped type has no bind processor of its own; fall back
            # to a hand-rolled serialization for bools, dates and JSON.
            type_ = self.underlying_type.python_type
            if issubclass(type_, bool):
                value = 'true' if value else 'false'
            elif issubclass(type_, (datetime.date, datetime.time)):
                value = value.isoformat()
            elif issubclass(type_, JSONType):
                value = json.dumps(value)
        return self.engine.encrypt(value)

    def process_result_value(self, value, dialect):
        """Decrypt value on the way out."""
        if value is None:
            return None
        self._update_key()
        decrypted_value = self.engine.decrypt(value)
        try:
            return self.underlying_type.process_result_value(
                decrypted_value, dialect
            )
        except AttributeError:
            # No result processor on the wrapped type; reverse the manual
            # serialization performed in process_bind_param.
            type_ = self.underlying_type.python_type
            if issubclass(type_, bool):
                return decrypted_value == 'true'
            if type_ in (datetime.datetime, datetime.time, datetime.date):
                return DatetimeHandler.process_value(
                    decrypted_value, type_
                )
            if issubclass(type_, JSONType):
                return json.loads(decrypted_value)
            # Any other type is reconstructed from its string form.
            return self.underlying_type.python_type(decrypted_value)

    def _coerce(self, value):
        # Delegate coercion to the wrapped type when it supports it.
        underlying = self.underlying_type
        if isinstance(underlying, ScalarCoercible):
            value = underlying._coerce(value)
        return value
class EncryptedType(StringEncryptedType):
    """Deprecated binary-backed variant of :class:`StringEncryptedType`.

    Identical behaviour, but the column impl is ``LargeBinary``, so the
    ciphertext is converted between ``str`` and ``bytes`` at the database
    boundary.
    """
    impl = LargeBinary

    def __init__(self, *args, **kwargs):
        warnings.warn(
            "The 'EncryptedType' class will change implementation from "
            "'LargeBinary' to 'String' in a future version. Use "
            "'StringEncryptedType' to use the 'String' implementation.",
            DeprecationWarning, stacklevel=2)
        super().__init__(*args, **kwargs)

    def process_bind_param(self, value, dialect):
        # The parent produces text ciphertext; store it as bytes.
        encrypted = super().process_bind_param(value=value, dialect=dialect)
        if isinstance(encrypted, str):
            encrypted = encrypted.encode()
        return encrypted

    def process_result_value(self, value, dialect):
        # Decode bytes coming from the DB before the parent decrypts them.
        if isinstance(value, bytes):
            value = value.decode()
        return super().process_result_value(value=value, dialect=dialect)
class DatetimeHandler:
    """Parses strings back into ``date``, ``datetime`` or ``time`` objects.

    Used when decrypting columns whose underlying type is temporal.
    """

    @classmethod
    def process_value(cls, value, python_type):
        """Parse *value* and narrow it to the requested *python_type*.

        Requires the ``python-dateutil`` package.
        """
        if not dateutil:
            raise ImproperlyConfigured(
                "'python-dateutil' is required to process datetimes"
            )
        parsed = datetime_parse(value)
        # Check datetime first: it is a subclass of date, so the order of
        # these issubclass tests matters.
        if issubclass(python_type, datetime.datetime):
            return parsed
        if issubclass(python_type, datetime.time):
            return parsed.time()
        if issubclass(python_type, datetime.date):
            return parsed.date()

View file

@ -0,0 +1,142 @@
class InvalidPaddingError(Exception):
    """Raised by ``unpad()`` when padded data fails validation (see
    ``PKCS5Padding.unpad``)."""
    pass
class Padding:
    """Base class for padding and unpadding.

    Subclasses implement ``pad``/``unpad`` for a fixed ``block_size``
    (the registry of available mechanisms is ``PADDING_MECHANISM``).
    """
    def __init__(self, block_size):
        # Size, in bytes, that padded output must be a multiple of.
        self.block_size = block_size

    def pad(self, value):
        raise NotImplementedError('Subclasses must implement this!')

    def unpad(self, value):
        raise NotImplementedError('Subclasses must implement this!')
class PKCS5Padding(Padding):
    """PKCS#5-style padding: append N bytes, each of value N."""

    def pad(self, value):
        if not isinstance(value, bytes):
            value = value.encode()
        # Always append at least one byte; a full block when already aligned.
        pad_len = self.block_size - len(value) % self.block_size
        return value + bytes((pad_len,)) * pad_len

    def unpad(self, value):
        """Validate and strip the padding.

        Any inconsistency raises a generic InvalidPaddingError rather than
        leaking details about what was wrong.
        """
        # Padded output is never empty and always a whole number of blocks.
        if not value or len(value) < self.block_size:
            raise InvalidPaddingError()
        if len(value) % self.block_size != 0:
            raise InvalidPaddingError()
        # Accept both bytes and str input; the last element encodes the
        # padding length.
        tail = value[-1]
        if isinstance(value, bytes):
            pad_len = tail
        if isinstance(value, str):
            pad_len = ord(tail)
        if pad_len == 0 or pad_len > self.block_size:
            raise InvalidPaddingError()

        def as_int(item):
            return ord(item) if isinstance(item, str) else item

        # Every padding byte must carry the padding length.
        if any(pad_len != as_int(item) for item in value[-pad_len:]):
            raise InvalidPaddingError()
        return value[0:-pad_len]
class OneAndZeroesPadding(Padding):
    """Padding of one ``0x80`` byte followed by zero bytes.

    Unpadding strips all trailing zero bytes and then the ``0x80`` marker.
    """
    BYTE_80 = 0x80
    BYTE_00 = 0x00

    def pad(self, value):
        if not isinstance(value, bytes):
            value = value.encode()
        pad_len = self.block_size - len(value) % self.block_size
        marker = bytes((self.BYTE_80,))
        filler = bytes((self.BYTE_00,)) * (pad_len - 1)
        return value + marker + filler

    def unpad(self, value):
        # Strip trailing zeros first, then the single 0x80 marker.
        stripped = value.rstrip(bytes((self.BYTE_00,)))
        return stripped.rstrip(bytes((self.BYTE_80,)))
class ZeroesPadding(Padding):
    """Padding of ``0x00`` bytes with the padding length in the last byte.

    Unpadding reads the final byte and strips that many bytes.
    """
    BYTE_00 = 0x00

    def pad(self, value):
        if not isinstance(value, bytes):
            value = value.encode()
        pad_len = self.block_size - len(value) % self.block_size
        filler = bytes((self.BYTE_00,)) * (pad_len - 1)
        return value + filler + bytes((pad_len,))

    def unpad(self, value):
        # The final byte/char records how many bytes to strip.
        tail = value[-1]
        if isinstance(value, bytes):
            pad_len = tail
        if isinstance(value, str):
            pad_len = ord(tail)
        return value[0:-pad_len]
class NaivePadding(Padding):
    """Naive padding and unpadding using ``*`` characters.

    Kept only for backwards compatibility; ambiguous when the plaintext
    itself ends in ``*``.
    """
    CHARACTER = b'*'

    def pad(self, value):
        needed = self.block_size - len(value) % self.block_size
        return value + needed * self.CHARACTER

    def unpad(self, value):
        return value.rstrip(self.CHARACTER)
# Registry of padding strategies selectable by name (see
# AesEngine._set_padding_mechanism); 'naive' is the legacy default.
PADDING_MECHANISM = {
    'pkcs5': PKCS5Padding,
    'oneandzeroes': OneAndZeroesPadding,
    'zeroes': ZeroesPadding,
    'naive': NaivePadding
}

View file

@ -0,0 +1,4 @@
# Module for enriched date, datetime type
from .arrow_datetime import ArrowDateTime # noqa
from .pendulum_date import PendulumDate # noqa
from .pendulum_datetime import PendulumDateTime # noqa

View file

@ -0,0 +1,39 @@
from collections.abc import Iterable
from datetime import datetime
from ...exceptions import ImproperlyConfigured
arrow = None
try:
import arrow
except ImportError:
pass
class ArrowDateTime:
    """Converts between ``arrow.Arrow`` objects and plain datetimes.

    Plugged into :class:`EnrichedDateTimeType` by :class:`ArrowType`.
    """
    def __init__(self):
        if not arrow:
            raise ImproperlyConfigured(
                "'arrow' package is required to use 'ArrowDateTime'"
            )
    def _coerce(self, impl, value):
        # Accept ISO-ish strings, argument iterables and datetimes; values
        # matching none of these branches are returned unchanged.
        if isinstance(value, str):
            value = arrow.get(value)
        elif isinstance(value, Iterable):
            value = arrow.get(*value)
        elif isinstance(value, datetime):
            value = arrow.get(value)
        return value
    def process_bind_param(self, impl, value, dialect):
        # Normalize to UTC; keep tzinfo only for timezone-aware columns.
        if value:
            utc_val = self._coerce(impl, value).to('UTC')
            return utc_val.datetime\
                if impl.timezone else utc_val.naive
        return value
    def process_result_value(self, impl, value, dialect):
        # Wrap the stored datetime back into an Arrow object.
        if value:
            return arrow.get(value)
        return value

View file

@ -0,0 +1,50 @@
from sqlalchemy import types
from ..scalar_coercible import ScalarCoercible
from .pendulum_date import PendulumDate
class EnrichedDateType(types.TypeDecorator, ScalarCoercible):
    """
    Date column type whose values are coerced through a pluggable date
    processor (currently pendulum only, via :class:`PendulumDate`).

    Example::

        from sqlalchemy_utils import EnrichedDateType
        import pendulum


        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            birthday = sa.Column(EnrichedDateType())


        user = User()
        user.birthday = pendulum.datetime(year=1995, month=7, day=11)
        session.add(user)
        session.commit()

    :param date_processor: class handling coercion and bind/result
        conversion; defaults to ``PendulumDate``.
    """
    impl = types.Date
    cache_ok = True
    def __init__(self, date_processor=PendulumDate, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The processor instance handles all coercion/bind/result logic.
        self.date_object = date_processor()
    def _coerce(self, value):
        return self.date_object._coerce(self.impl, value)
    def process_bind_param(self, value, dialect):
        # Delegate conversion of Python values to the processor.
        return self.date_object.process_bind_param(self.impl, value, dialect)
    def process_result_value(self, value, dialect):
        # Delegate re-hydration of database values to the processor.
        return self.date_object.process_result_value(self.impl, value, dialect)
    def process_literal_param(self, value, dialect):
        return value
    @property
    def python_type(self):
        return self.impl.type.python_type

View file

@ -0,0 +1,51 @@
from sqlalchemy import types
from ..scalar_coercible import ScalarCoercible
from .pendulum_datetime import PendulumDateTime
class EnrichedDateTimeType(types.TypeDecorator, ScalarCoercible):
    """
    DateTime column type whose values are coerced through a pluggable
    processor -- :class:`PendulumDateTime` by default, or
    :class:`ArrowDateTime` (used by :class:`ArrowType`).

    Example::

        from sqlalchemy_utils import EnrichedDateTimeType
        import pendulum


        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            created_at = sa.Column(EnrichedDateTimeType())


        user = User()
        user.created_at = pendulum.now()
        session.add(user)
        session.commit()

    :param datetime_processor: class handling coercion and bind/result
        conversion; defaults to ``PendulumDateTime``.
    """
    impl = types.DateTime
    cache_ok = True
    def __init__(self, datetime_processor=PendulumDateTime, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The processor instance handles all coercion/bind/result logic.
        self.dt_object = datetime_processor()
    def _coerce(self, value):
        return self.dt_object._coerce(self.impl, value)
    def process_bind_param(self, value, dialect):
        # Delegate conversion of Python values to the processor.
        return self.dt_object.process_bind_param(self.impl, value, dialect)
    def process_result_value(self, value, dialect):
        # Delegate re-hydration of database values to the processor.
        return self.dt_object.process_result_value(self.impl, value, dialect)
    def process_literal_param(self, value, dialect):
        return value
    @property
    def python_type(self):
        return self.impl.type.python_type

View file

@ -0,0 +1,32 @@
from ...exceptions import ImproperlyConfigured
from .pendulum_datetime import PendulumDateTime
pendulum = None
try:
import pendulum
except ImportError:
pass
class PendulumDate(PendulumDateTime):
    """Coerces values to ``pendulum.Date`` for :class:`EnrichedDateType`."""

    def __init__(self):
        if not pendulum:
            raise ImproperlyConfigured(
                "'pendulum' package is required to use 'PendulumDate'"
            )

    def _coerce(self, impl, value):
        # Reuse the datetime coercion of the parent, then truncate to a
        # date; existing pendulum.Date values pass through.
        if value and not isinstance(value, pendulum.Date):
            value = super()._coerce(impl, value).date()
        return value

    def process_result_value(self, impl, value, dialect):
        if not value:
            return value
        return pendulum.parse(value.isoformat()).date()

    def process_bind_param(self, impl, value, dialect):
        if not value:
            return value
        return self._coerce(impl, value)

View file

@ -0,0 +1,49 @@
from datetime import datetime
from ...exceptions import ImproperlyConfigured
pendulum = None
try:
import pendulum
except ImportError:
pass
class PendulumDateTime:
    """Coerces assorted inputs to ``pendulum.DateTime`` objects and
    normalizes bound values to UTC."""

    def __init__(self):
        if not pendulum:
            raise ImproperlyConfigured(
                "'pendulum' package is required to use 'PendulumDateTime'"
            )

    def _coerce(self, impl, value):
        """Accept DateTime objects, unix timestamps (numbers or digit
        strings), plain datetimes and parseable strings."""
        if value is None:
            return None
        if isinstance(value, pendulum.DateTime):
            return value
        if isinstance(value, (int, float)):
            return pendulum.from_timestamp(value)
        if isinstance(value, str) and value.isdigit():
            return pendulum.from_timestamp(int(value))
        if isinstance(value, datetime):
            # Rebuild field by field; any tzinfo on the input is dropped.
            return pendulum.datetime(
                value.year,
                value.month,
                value.day,
                value.hour,
                value.minute,
                value.second,
                value.microsecond
            )
        return pendulum.parse(value)

    def process_bind_param(self, impl, value, dialect):
        if not value:
            return value
        return self._coerce(impl, value).in_tz('UTC')

    def process_result_value(self, impl, value, dialect):
        if not value:
            return value
        return pendulum.parse(value.isoformat())

View file

@ -0,0 +1,52 @@
from ipaddress import ip_address
from sqlalchemy import types
from .scalar_coercible import ScalarCoercible
class IPAddressType(ScalarCoercible, types.TypeDecorator):
    """
    Stores IP addresses as strings on the way in and converts them back to
    :mod:`ipaddress` address objects on the way out.

    ::

        from sqlalchemy_utils import IPAddressType


        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, autoincrement=True)
            name = sa.Column(sa.Unicode(255))
            ip_address = sa.Column(IPAddressType)


        user = User()
        user.ip_address = '123.123.123.123'
        session.add(user)
        session.commit()

        user.ip_address  # IPAddress object
    """
    impl = types.Unicode(50)
    cache_ok = True

    def __init__(self, max_length=50, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Allow callers to widen the backing column if needed.
        self.impl = types.Unicode(max_length)

    def process_bind_param(self, value, dialect):
        # Falsy values (None, '') bind as NULL.
        if not value:
            return None
        return str(value)

    def process_result_value(self, value, dialect):
        # Falsy values come back as None.
        if not value:
            return None
        return ip_address(value)

    def _coerce(self, value):
        # Scalar coercion mirrors result processing.
        return ip_address(value) if value else None

    @property
    def python_type(self):
        return self.impl.type.python_type

View file

@ -0,0 +1,77 @@
import json
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql.base import ischema_names
# Older SQLAlchemy versions may not ship the native postgresql JSON type;
# in that case fall back to a minimal UserDefinedType emitting raw 'json'.
try:
    from sqlalchemy.dialects.postgresql import JSON
    has_postgres_json = True
except ImportError:
    class PostgresJSONType(sa.types.UserDefinedType):
        """
        Fallback PostgreSQL ``json`` column type, used when SQLAlchemy does
        not provide a native JSON type.
        """
        def get_col_spec(self):
            # Rendered verbatim as the column type name in DDL.
            return 'json'

    # Make reflection aware of the fallback type.
    ischema_names['json'] = PostgresJSONType
    has_postgres_json = False
class JSONType(sa.types.TypeDecorator):
    """
    JSONType offers way of saving JSON data structures to database. On
    PostgreSQL the underlying implementation of this data type is 'json' while
    on other databases its simply 'text'.

    ::

        from sqlalchemy_utils import JSONType


        class Product(Base):
            __tablename__ = 'product'
            id = sa.Column(sa.Integer, autoincrement=True)
            name = sa.Column(sa.Unicode(50))
            details = sa.Column(JSONType)


        product = Product()
        product.details = {
            'color': 'red',
            'type': 'car',
            'max-speed': '400 mph'
        }
        session.commit()
    """
    impl = sa.UnicodeText
    hashable = False
    cache_ok = True

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def load_dialect_impl(self, dialect):
        # Non-postgresql backends store serialized text.
        if dialect.name != 'postgresql':
            return dialect.type_descriptor(self.impl)
        # Prefer SQLAlchemy's native JSON type; otherwise use our fallback.
        pg_type = JSON() if has_postgres_json else PostgresJSONType()
        return dialect.type_descriptor(pg_type)

    def process_bind_param(self, value, dialect):
        # The native postgresql JSON type serializes values itself.
        if dialect.name == 'postgresql' and has_postgres_json:
            return value
        return value if value is None else json.dumps(value)

    def process_result_value(self, value, dialect):
        # Postgresql already returns deserialized values.
        if dialect.name == 'postgresql':
            return value
        return value if value is None else json.loads(value)

View file

@ -0,0 +1,76 @@
from sqlalchemy import types
from ..exceptions import ImproperlyConfigured
from .scalar_coercible import ScalarCoercible
babel = None
try:
import babel
except ImportError:
pass
class LocaleType(ScalarCoercible, types.TypeDecorator):
    """
    LocaleType saves Babel_ Locale objects into database. The Locale objects
    are converted to string on the way in and back to object on the way out.

    In order to use LocaleType you need to install Babel_ first.

    .. _Babel: https://babel.pocoo.org/

    ::

        from sqlalchemy_utils import LocaleType
        from babel import Locale


        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, autoincrement=True)
            name = sa.Column(sa.Unicode(50))
            locale = sa.Column(LocaleType)


        user = User()
        user.locale = Locale('en_US')
        session.add(user)
        session.commit()


    Like many other types this type also supports scalar coercion:

    ::

        user.locale = 'de_DE'
        user.locale  # Locale('de', territory='DE')

    :raises ImproperlyConfigured: if the ``babel`` package is not installed.
    """
    impl = types.Unicode(10)
    cache_ok = True

    def __init__(self):
        if babel is None:
            # Fix: message previously read "Babel packaged is required with
            # LocaleType." — corrected wording, phrased consistently with the
            # other optional-dependency types in this package.
            raise ImproperlyConfigured(
                "'babel' package is required to use 'LocaleType'"
            )

    def process_bind_param(self, value, dialect):
        # Locale objects are stringified, plain strings pass through;
        # any other value (including None) implicitly binds NULL.
        if isinstance(value, babel.Locale):
            return str(value)

        if isinstance(value, str):
            return value

    def process_result_value(self, value, dialect):
        # NULL comes back as None (implicit return).
        if value is not None:
            return babel.Locale.parse(value)

    def _coerce(self, value):
        # Scalar coercion: parse anything that is not already a Locale.
        if value is not None and not isinstance(value, babel.Locale):
            return babel.Locale.parse(value)
        return value

View file

@ -0,0 +1,121 @@
from sqlalchemy import types
from sqlalchemy.dialects.postgresql import ARRAY
from sqlalchemy.dialects.postgresql.base import ischema_names, PGTypeCompiler
from sqlalchemy.sql import expression
from ..primitives import Ltree
from .scalar_coercible import ScalarCoercible
class LtreeType(types.Concatenable, types.UserDefinedType, ScalarCoercible):
    """PostgreSQL ltree type.

    The LtreeType datatype can be used for representing labels of data stored
    in hierarchical tree-like structure. For more detailed information please
    refer to https://www.postgresql.org/docs/current/ltree.html

    ::

        from sqlalchemy_utils import LtreeType, Ltree


        class DocumentSection(Base):
            __tablename__ = 'document_section'
            id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
            path = sa.Column(LtreeType)


        section = DocumentSection(path=Ltree('Countries.Finland'))
        session.add(section)
        session.commit()

        section.path  # Ltree('Countries.Finland')

    .. note::
        Using :class:`LtreeType`, :class:`LQUERY` and :class:`LTXTQUERY` types
        may require installation of Postgresql ltree extension on the server
        side. Please visit https://www.postgresql.org/ for details.
    """

    cache_ok = True

    class comparator_factory(types.Concatenable.Comparator):
        def ancestor_of(self, other):
            # '@>' — left operand is an ancestor of (or equal to) the right.
            if isinstance(other, list):
                return self.op('@>')(expression.cast(other, ARRAY(LtreeType)))
            else:
                return self.op('@>')(other)

        def descendant_of(self, other):
            # '<@' — left operand is a descendant of (or equal to) the right.
            if isinstance(other, list):
                return self.op('<@')(expression.cast(other, ARRAY(LtreeType)))
            else:
                return self.op('<@')(other)

        def lquery(self, other):
            # '?' matches against an array of lqueries, '~' against one.
            if isinstance(other, list):
                return self.op('?')(expression.cast(other, ARRAY(LQUERY)))
            else:
                return self.op('~')(other)

        def ltxtquery(self, other):
            # '@' — full-text-style match against an ltxtquery.
            return self.op('@')(other)

    def bind_processor(self, dialect):
        def process(value):
            # Send the raw dotted path; falsy values implicitly bind None.
            if value:
                return value.path
        return process

    def result_processor(self, dialect, coltype):
        def process(value):
            # Wrap database strings back into Ltree objects.
            return self._coerce(value)
        return process

    def literal_processor(self, dialect):
        def process(value):
            # Escape single quotes for inline SQL rendering.
            value = value.replace("'", "''")
            return "'%s'" % value
        return process

    __visit_name__ = 'LTREE'

    def _coerce(self, value):
        # Falsy values (None, '') coerce to None (implicit return).
        if value:
            return Ltree(value)
class LQUERY(types.TypeEngine):
    """PostgreSQL LQUERY type.

    See :class:`LTREE` for details.
    """
    __visit_name__ = 'LQUERY'
class LTXTQUERY(types.TypeEngine):
    """PostgreSQL LTXTQUERY type.

    See :class:`LTREE` for details.
    """
    __visit_name__ = 'LTXTQUERY'
# Register the ltree family with PostgreSQL reflection so columns of these
# types round-trip through table introspection.
ischema_names['ltree'] = LtreeType
ischema_names['lquery'] = LQUERY
ischema_names['ltxtquery'] = LTXTQUERY


def visit_LTREE(self, type_, **kw):
    # DDL rendering for LtreeType columns.
    return 'LTREE'


def visit_LQUERY(self, type_, **kw):
    # DDL rendering for LQUERY columns.
    return 'LQUERY'


def visit_LTXTQUERY(self, type_, **kw):
    # DDL rendering for LTXTQUERY columns.
    return 'LTXTQUERY'


# Teach the PostgreSQL type compiler to render the custom __visit_name__s.
PGTypeCompiler.visit_LTREE = visit_LTREE
PGTypeCompiler.visit_LQUERY = visit_LQUERY
PGTypeCompiler.visit_LTXTQUERY = visit_LTXTQUERY

View file

@ -0,0 +1,259 @@
import weakref
from sqlalchemy import types
from sqlalchemy.dialects import oracle, postgresql, sqlite
from sqlalchemy.ext.mutable import Mutable
from ..exceptions import ImproperlyConfigured
from .scalar_coercible import ScalarCoercible
passlib = None
try:
import passlib
from passlib.context import LazyCryptContext
except ImportError:
pass
class Password(Mutable):
    """
    Mutable value object wrapping a passlib hash.

    A ``Password`` either holds a hash (``self.hash``) or a not-yet-hashed
    secret (``self.secret``); comparing with ``==`` verifies a plaintext
    candidate against the stored hash via the associated passlib context.
    """

    @classmethod
    def coerce(cls, key, value):
        """Coerce attribute assignments into Password instances."""
        if isinstance(value, Password):
            return value

        if isinstance(value, (str, bytes)):
            # Treat raw strings/bytes as secrets to be hashed later.
            return cls(value, secret=True)

        # Fix: propagate the result of Mutable.coerce (it passes None through
        # and raises ValueError for unsupported types) instead of silently
        # discarding its return value.
        return super().coerce(key, value)

    def __init__(self, value, context=None, secret=False):
        # Store the hash (if it is one).
        self.hash = value if not secret else None

        # Store the secret if we have one.
        self.secret = value if secret else None

        # The hash should be bytes.
        if isinstance(self.hash, str):
            self.hash = self.hash.encode('utf8')

        # Save weakref of the password context (if we have one) so the
        # Password does not keep the context alive.
        self.context = weakref.proxy(context) if context is not None else None

    def __eq__(self, value):
        if self.hash is None or value is None:
            # Ensure that we don't continue comparison if one of us is None.
            return self.hash is value

        if isinstance(value, Password):
            # Comparing 2 hashes isn't very useful; but this equality
            # method breaks otherwise.
            return value.hash == self.hash

        if self.context is None:
            # Compare 2 hashes again as we don't know how to validate.
            return value == self

        if isinstance(value, (str, bytes)):
            valid, new = self.context.verify_and_update(value, self.hash)
            if valid and new:
                # New hash was calculated due to various reasons; stored one
                # wasn't optimal, etc.
                self.hash = new

                # The hash should be bytes.
                if isinstance(self.hash, str):
                    self.hash = self.hash.encode('utf8')

                # Notify SQLAlchemy that this mutable value changed.
                self.changed()

            return valid

        return False

    def __ne__(self, value):
        return not (self == value)
class PasswordType(ScalarCoercible, types.TypeDecorator):
    """
    PasswordType hashes passwords as they come into the database and allows
    verifying them using a Pythonic interface. This Pythonic interface
    relies on setting up automatic data type coercion using the
    :func:`~sqlalchemy_utils.listeners.force_auto_coercion` function.

    All keyword arguments (aside from max_length) are forwarded to the
    construction of a `passlib.context.LazyCryptContext` object, which
    also supports deferred configuration via the `onload` callback.

    The following usage will create a password column that will
    automatically hash new passwords as `pbkdf2_sha512` but still compare
    passwords against pre-existing `md5_crypt` hashes. As passwords are
    compared; the password hash in the database will be updated to
    be `pbkdf2_sha512`.

    ::

        class Model(Base):
            password = sa.Column(PasswordType(
                schemes=[
                    'pbkdf2_sha512',
                    'md5_crypt'
                ],
                deprecated=['md5_crypt']
            ))

    Verifying password is as easy as:

    ::

        target = Model()
        target.password = 'b'
        # '$5$rounds=80000$H.............'

        target.password == 'b'
        # True

    Lazy configuration of the type with Flask config:

    ::

        import flask
        from sqlalchemy_utils import PasswordType, force_auto_coercion

        force_auto_coercion()

        class User(db.Model):
            __tablename__ = 'user'

            password = db.Column(
                PasswordType(
                    # The returned dictionary is forwarded to the CryptContext
                    onload=lambda **kwargs: dict(
                        schemes=flask.current_app.config['PASSWORD_SCHEMES'],
                        **kwargs
                    ),
                ),
                unique=False,
                nullable=False,
            )
    """
    impl = types.VARBINARY(1024)
    # NOTE(review): this class attribute is shadowed by the python_type
    # property defined at the bottom of the class; it is effectively unused.
    python_type = Password
    cache_ok = True

    def __init__(self, max_length=None, **kwargs):
        # Fail if passlib is not found.
        if passlib is None:
            raise ImproperlyConfigured(
                "'passlib' is required to use 'PasswordType'"
            )

        # Construct the passlib crypt context.
        self.context = LazyCryptContext(**kwargs)
        self._max_length = max_length

    @property
    def hashing_method(self):
        # passlib renamed 'encrypt' to 'hash'; prefer the modern name.
        return (
            'hash'
            if hasattr(self.context, 'hash')
            else 'encrypt'
        )

    @property
    def length(self):
        """Get column length."""
        # Lazily computed so the (possibly deferred) context is only
        # inspected when the length is actually needed.
        if self._max_length is None:
            self._max_length = self.calculate_max_length()

        return self._max_length

    def calculate_max_length(self):
        # Calculate the largest possible encoded password.
        # name + rounds + salt + hash + ($ * 4) of largest hash
        max_lengths = [1024]
        for name in self.context.schemes():
            scheme = getattr(__import__('passlib.hash').hash, name)
            length = 4 + len(scheme.name)
            length += len(str(getattr(scheme, 'max_rounds', '')))
            length += (getattr(scheme, 'max_salt_size', 0) or 0)
            length += getattr(
                scheme,
                'encoded_checksum_size',
                scheme.checksum_size
            )
            max_lengths.append(length)

        # Return the maximum calculated max length.
        return max(max_lengths)

    def load_dialect_impl(self, dialect):
        if dialect.name == 'postgresql':
            # Use a BYTEA type for postgresql.
            impl = postgresql.BYTEA(self.length)
        elif dialect.name == 'oracle':
            # Use a RAW type for oracle.
            impl = oracle.RAW(self.length)
        elif dialect.name == 'sqlite':
            # Use a BLOB type for sqlite
            impl = sqlite.BLOB(self.length)
        else:
            # Use a VARBINARY for all other dialects.
            impl = types.VARBINARY(self.length)
        return dialect.type_descriptor(impl)

    def process_bind_param(self, value, dialect):
        if isinstance(value, Password):
            # If were given a password secret; hash it.
            if value.secret is not None:
                return self._hash(value.secret).encode('utf8')

            # Value has already been hashed.
            return value.hash

        if isinstance(value, str):
            # Assume value has not been hashed.
            return self._hash(value).encode('utf8')

    def process_result_value(self, value, dialect):
        # NULL comes back as None (implicit return).
        if value is not None:
            return Password(value, self.context)

    def _hash(self, value):
        # Dispatch to context.hash or context.encrypt depending on passlib
        # version (see hashing_method).
        return getattr(self.context, self.hashing_method)(value)

    def _coerce(self, value):
        if value is None:
            return

        if not isinstance(value, Password):
            # Hash the password using the default scheme.
            value = self._hash(value).encode('utf8')
            return Password(value, context=self.context)
        else:
            # If were given a password object; ensure the context is right.
            value.context = weakref.proxy(self.context)

            # If were given a password secret; hash it.
            if value.secret is not None:
                value.hash = self._hash(value.secret).encode('utf8')
                value.secret = None

        return value

    @property
    def python_type(self):
        return self.impl.type.python_type


# Propagate Password mutation events to columns of PasswordType.
Password.associate_with(PasswordType)

View file

@ -0,0 +1,390 @@
"""
CompositeType provides means to interact with
`PostgreSQL composite types`_. Currently this type features:
* Easy attribute access to composite type fields
* Supports SQLAlchemy TypeDecorator types
* Ability to include composite types as part of PostgreSQL arrays
* Type creation and dropping
Installation
^^^^^^^^^^^^
CompositeType automatically attaches `before_create` and `after_drop` DDL
listeners. These listeners create and drop the composite type in the
database. This means it works out of the box in your test environment where
you create the tables on each test run.
When you already have your database set up you should call
:func:`register_composites` after you've set up all models.
::
register_composites(conn)
Usage
^^^^^
::
from collections import OrderedDict
import sqlalchemy as sa
from sqlalchemy_utils import CompositeType, CurrencyType
class Account(Base):
__tablename__ = 'account'
id = sa.Column(sa.Integer, primary_key=True)
balance = sa.Column(
CompositeType(
'money_type',
[
sa.Column('currency', CurrencyType),
sa.Column('amount', sa.Integer)
]
)
)
Creation
~~~~~~~~
When creating CompositeType, you can either pass in a tuple or a dictionary.
::
account1 = Account()
account1.balance = ('USD', 15)
account2 = Account()
account2.balance = {'currency': 'USD', 'amount': 15}
session.add(account1)
session.add(account2)
session.commit()
Accessing fields
^^^^^^^^^^^^^^^^
CompositeType provides attribute access to underlying fields. In the following
example we find all accounts with balance amount more than 5000.
::
session.query(Account).filter(Account.balance.amount > 5000)
Arrays of composites
^^^^^^^^^^^^^^^^^^^^
::
from sqlalchemy.dialects.postgresql import ARRAY
class Account(Base):
__tablename__ = 'account'
id = sa.Column(sa.Integer, primary_key=True)
balances = sa.Column(
ARRAY(
CompositeType(
'money_type',
[
sa.Column('currency', CurrencyType),
sa.Column('amount', sa.Integer)
]
),
dimensions=1
)
)
.. _PostgreSQL composite types:
https://www.postgresql.org/docs/current/rowtypes.html
Related links:
https://schinckel.net/2014/09/24/using-postgres-composite-types-in-django/
"""
from collections import namedtuple
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.schema import _CreateDropBase
from sqlalchemy.sql.expression import FunctionElement
from sqlalchemy.types import (
SchemaType,
to_instance,
TypeDecorator,
UserDefinedType
)
from .. import ImproperlyConfigured
psycopg2 = None
CompositeCaster = None
adapt = None
AsIs = None
register_adapter = None
try:
import psycopg2
from psycopg2.extensions import adapt, AsIs, register_adapter
from psycopg2.extras import CompositeCaster
except ImportError:
pass
class CompositeElement(FunctionElement):
"""
Instances of this class wrap a Postgres composite type.
"""
def __init__(self, base, field, type_):
self.name = field
self.type = to_instance(type_)
super().__init__(base)
@compiles(CompositeElement)
def _compile_pgelem(expr, compiler, **kw):
return f'({compiler.process(expr.clauses, **kw)}).{expr.name}'
# TODO: Make the registration work on connection level instead of global level
registered_composites = {}
class CompositeType(UserDefinedType, SchemaType):
"""
Represents a PostgreSQL composite type.
:param name:
Name of the composite type.
:param columns:
List of columns that this composite type consists of
"""
python_type = tuple
class comparator_factory(UserDefinedType.Comparator):
def __getattr__(self, key):
try:
type_ = self.type.typemap[key]
except KeyError:
raise KeyError(
"Type '{}' doesn't have an attribute: '{}'".format(
self.name, key
)
)
return CompositeElement(self.expr, key, type_)
def __init__(self, name, columns, quote=None, **kwargs):
if psycopg2 is None:
raise ImproperlyConfigured(
"'psycopg2' package is required in order to use CompositeType."
)
SchemaType.__init__(
self,
name=name,
quote=quote
)
self.columns = columns
if name in registered_composites:
self.type_cls = registered_composites[name].type_cls
else:
self.type_cls = namedtuple(
self.name, [c.name for c in columns]
)
registered_composites[name] = self
class Caster(CompositeCaster):
def make(obj, values):
return self.type_cls(*values)
self.caster = Caster
attach_composite_listeners()
def get_col_spec(self):
return self.name
def bind_processor(self, dialect):
def process(value):
if value is None:
return None
processed_value = []
for i, column in enumerate(self.columns):
current_value = (
value.get(column.name)
if isinstance(value, dict)
else value[i]
)
if isinstance(column.type, TypeDecorator):
processed_value.append(
column.type.process_bind_param(
current_value, dialect
)
)
else:
processed_value.append(current_value)
return self.type_cls(*processed_value)
return process
def result_processor(self, dialect, coltype):
def process(value):
if value is None:
return None
cls = value.__class__
kwargs = {}
for column in self.columns:
if isinstance(column.type, TypeDecorator):
kwargs[column.name] = column.type.process_result_value(
getattr(value, column.name), dialect
)
else:
kwargs[column.name] = getattr(value, column.name)
return cls(**kwargs)
return process
def create(self, bind=None, checkfirst=None):
if (
not checkfirst or
not bind.dialect.has_type(bind, self.name, schema=self.schema)
):
bind.execute(CreateCompositeType(self))
def drop(self, bind=None, checkfirst=True):
if (
checkfirst and
bind.dialect.has_type(bind, self.name, schema=self.schema)
):
bind.execute(DropCompositeType(self))
def register_psycopg2_composite(dbapi_connection, composite):
psycopg2.extras.register_composite(
composite.name,
dbapi_connection,
globally=True,
factory=composite.caster
)
def adapt_composite(value):
dialect = PGDialect_psycopg2()
adapted = [
adapt(
getattr(value, column.name)
if not isinstance(column.type, TypeDecorator)
else column.type.process_bind_param(
getattr(value, column.name),
dialect
)
)
for column in
composite.columns
]
for value in adapted:
if hasattr(value, 'prepare'):
value.prepare(dbapi_connection)
values = [
value.getquoted().decode(dbapi_connection.encoding)
for value in adapted
]
return AsIs(
'({})::{}'.format(
', '.join(values),
dialect.identifier_preparer.quote(composite.name)
)
)
register_adapter(composite.type_cls, adapt_composite)
def get_driver_connection(connection):
try:
# SQLAlchemy 2.0
return connection.connection.driver_connection
except AttributeError:
return connection.connection.connection
def before_create(target, connection, **kw):
for name, composite in registered_composites.items():
composite.create(connection, checkfirst=True)
register_psycopg2_composite(
get_driver_connection(connection),
composite
)
def after_drop(target, connection, **kw):
for name, composite in registered_composites.items():
composite.drop(connection, checkfirst=True)
def register_composites(connection):
for name, composite in registered_composites.items():
register_psycopg2_composite(
get_driver_connection(connection),
composite
)
def attach_composite_listeners():
listeners = [
(sa.MetaData, 'before_create', before_create),
(sa.MetaData, 'after_drop', after_drop),
]
for listener in listeners:
if not sa.event.contains(*listener):
sa.event.listen(*listener)
def remove_composite_listeners():
listeners = [
(sa.MetaData, 'before_create', before_create),
(sa.MetaData, 'after_drop', after_drop),
]
for listener in listeners:
if sa.event.contains(*listener):
sa.event.remove(*listener)
class CreateCompositeType(_CreateDropBase):
pass
@compiles(CreateCompositeType)
def _visit_create_composite_type(create, compiler, **kw):
type_ = create.element
fields = ', '.join(
'{name} {type}'.format(
name=column.name,
type=compiler.dialect.type_compiler.process(
to_instance(column.type)
)
)
for column in type_.columns
)
return 'CREATE TYPE {name} AS ({fields})'.format(
name=compiler.preparer.format_type(type_),
fields=fields
)
class DropCompositeType(_CreateDropBase):
pass
@compiles(DropCompositeType)
def _visit_drop_composite_type(drop, compiler, **kw):
type_ = drop.element
return f'DROP TYPE {compiler.preparer.format_type(type_)}'

View file

@ -0,0 +1,204 @@
"""
.. note::
The `phonenumbers`_ package must be installed to use PhoneNumber types.
.. _phonenumbers: https://github.com/daviddrysdale/python-phonenumbers
"""
from sqlalchemy import exc, types
from ..exceptions import ImproperlyConfigured
from ..utils import str_coercible
from .scalar_coercible import ScalarCoercible
try:
import phonenumbers
from phonenumbers.phonenumber import PhoneNumber as BasePhoneNumber
from phonenumbers.phonenumberutil import NumberParseException
except ImportError:
phonenumbers = None
BasePhoneNumber = object
NumberParseException = Exception
class PhoneNumberParseException(NumberParseException, exc.DontWrapMixin):
    """
    Wraps exceptions from phonenumbers with SQLAlchemy's DontWrapMixin
    so we get more meaningful exceptions on validation failure instead of the
    StatementException

    Clients can catch this as either a PhoneNumberParseException or
    NumberParseException from the phonenumbers library.
    """
    pass
@str_coercible
class PhoneNumber(BasePhoneNumber):
    """
    Extends a PhoneNumber class from `Python phonenumbers library`_. Adds
    different phone number formats to attributes, so they can be easily used
    in templates. Phone number validation method is also implemented.

    Takes the raw phone number and country code as params and parses them
    into a PhoneNumber object.

    .. _Python phonenumbers library:
       https://github.com/daviddrysdale/python-phonenumbers

    ::

        from sqlalchemy_utils import PhoneNumber


        class User(self.Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
            name = sa.Column(sa.Unicode(255))
            _phone_number = sa.Column(sa.Unicode(20))
            country_code = sa.Column(sa.Unicode(8))

            phone_number = sa.orm.composite(
                PhoneNumber,
                _phone_number,
                country_code
            )


        user = User(phone_number=PhoneNumber('0401234567', 'FI'))
        user.phone_number.e164  # '+358401234567'
        user.phone_number.international  # '+358 40 1234567'
        user.phone_number.national  # '040 1234567'
        user.country_code  # 'FI'

    :param raw_number:
        String representation of the phone number.
    :param region:
        Region of the phone number.
    :param check_region:
        Whether to check the supplied region parameter;
        should always be True for external callers.
        Can be useful for short codes or toll free
    """
    def __init__(self, raw_number, region=None, check_region=True):
        # Bail if phonenumbers is not found.
        if phonenumbers is None:
            raise ImproperlyConfigured(
                "The 'phonenumbers' package is required to use 'PhoneNumber'"
            )
        try:
            self._phone_number = phonenumbers.parse(
                raw_number, region, _check_region=check_region
            )
        except NumberParseException as e:
            # Wrap exception so SQLAlchemy doesn't swallow it as a
            # StatementError
            #
            # Worth noting that if -1 shows up as the error_type
            # it's likely because the API has changed upstream and these
            # bindings need to be updated.
            raise PhoneNumberParseException(getattr(e, "error_type", -1), str(e))

        # Copy the parsed fields onto this instance so it behaves as a
        # regular phonenumbers.PhoneNumber value object.
        super().__init__(
            country_code=self._phone_number.country_code,
            national_number=self._phone_number.national_number,
            extension=self._phone_number.extension,
            italian_leading_zero=self._phone_number.italian_leading_zero,
            raw_input=self._phone_number.raw_input,
            country_code_source=self._phone_number.country_code_source,
            preferred_domestic_carrier_code=(
                self._phone_number.preferred_domestic_carrier_code
            ),
        )
        self.region = region
        # Pre-render the common string formats for easy template use.
        self.national = phonenumbers.format_number(
            self._phone_number, phonenumbers.PhoneNumberFormat.NATIONAL
        )
        self.international = phonenumbers.format_number(
            self._phone_number, phonenumbers.PhoneNumberFormat.INTERNATIONAL
        )
        self.e164 = phonenumbers.format_number(
            self._phone_number, phonenumbers.PhoneNumberFormat.E164
        )

    def __composite_values__(self):
        # Values used by sa.orm.composite: stored number and region.
        return self.national, self.region

    def is_valid_number(self):
        return phonenumbers.is_valid_number(self._phone_number)

    def __unicode__(self):
        return self.national

    def __hash__(self):
        # Hash by canonical E164 form so equal numbers hash equally.
        return hash(self.e164)
class PhoneNumberType(ScalarCoercible, types.TypeDecorator):
    """
    Changes PhoneNumber objects to a string representation on the way in and
    changes them back to PhoneNumber objects on the way out. If E164 is used
    as storing format, no country code is needed for parsing the database
    value to PhoneNumber object.

    ::

        class User(self.Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
            name = sa.Column(sa.Unicode(255))
            phone_number = sa.Column(PhoneNumberType())


        user = User(phone_number='+358401234567')

        user.phone_number.e164  # '+358401234567'
        user.phone_number.international  # '+358 40 1234567'
        user.phone_number.national  # '040 1234567'
    """
    STORE_FORMAT = "e164"
    impl = types.Unicode(20)
    python_type = PhoneNumber
    cache_ok = True

    def __init__(self, region="US", max_length=20, *args, **kwargs):
        # Bail if phonenumbers is not found.
        if phonenumbers is None:
            raise ImproperlyConfigured(
                "The 'phonenumbers' package is required to use 'PhoneNumberType'"
            )

        super().__init__(*args, **kwargs)
        self.region = region
        self.impl = types.Unicode(max_length)

    def process_bind_param(self, value, dialect):
        # Falsy values pass through untouched.
        if not value:
            return value

        parsed = (
            value
            if isinstance(value, PhoneNumber)
            else PhoneNumber(value, region=self.region)
        )

        if self.STORE_FORMAT == "e164" and parsed.extension:
            # E164 carries no extension; append it in RFC3966-like form.
            return f"{parsed.e164};ext={parsed.extension}"

        return getattr(parsed, self.STORE_FORMAT)

    def process_result_value(self, value, dialect):
        # Falsy values pass through untouched.
        return PhoneNumber(value, self.region) if value else value

    def _coerce(self, value):
        # Wrap raw strings; empty values normalize to None.
        if value and not isinstance(value, PhoneNumber):
            value = PhoneNumber(value, region=self.region)

        return value or None

View file

@ -0,0 +1,480 @@
"""
SQLAlchemy-Utils provides wide variety of range data types. All range data
types return Interval objects of intervals_ package. In order to use range data
types you need to install intervals_ with:
::
pip install intervals
Intervals package provides good chunk of additional interval operators that for
example psycopg2 range objects do not support.
Some good reading for practical interval implementations:
https://wiki.postgresql.org/images/f/f0/Range-types.pdf
Range type initialization
-------------------------
::
from sqlalchemy_utils import IntRangeType
class Event(Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, autoincrement=True)
name = sa.Column(sa.Unicode(255))
estimated_number_of_persons = sa.Column(IntRangeType)
You can also set a step parameter for range type. The values that are not
multipliers of given step will be rounded up to nearest step multiplier.
::
from sqlalchemy_utils import IntRangeType
class Event(Base):
__tablename__ = 'user'
id = sa.Column(sa.Integer, autoincrement=True)
name = sa.Column(sa.Unicode(255))
estimated_number_of_persons = sa.Column(IntRangeType(step=1000))
event = Event(estimated_number_of_persons=[100, 1200])
event.estimated_number_of_persons.lower # 0
event.estimated_number_of_persons.upper # 1000
Range type operators
--------------------
SQLAlchemy-Utils supports many range type operators. These operators follow the
`intervals` package interval coercion rules.
So for example when we make a query such as:
::
session.query(Car).filter(Car.price_range == 300)
It is essentially the same as:
::
session.query(Car).filter(Car.price_range == DecimalInterval([300, 300]))
Comparison operators
^^^^^^^^^^^^^^^^^^^^
All range types support all comparison operators (>, >=, ==, !=, <=, <).
::
Car.price_range < [12, 300]
Car.price_range == [12, 300]
Car.price_range < 300
Car.price_range > (300, 500)
# Whether or not range is strictly left of another range
Car.price_range << [300, 500]
# Whether or not range is strictly right of another range
Car.price_range >> [300, 500]
Membership operators
^^^^^^^^^^^^^^^^^^^^
::
Car.price_range.contains([300, 500])
Car.price_range.contained_by([300, 500])
Car.price_range.in_([[300, 500], [800, 900]])
~ Car.price_range.in_([[300, 400], [700, 800]])
Length
^^^^^^
SQLAlchemy-Utils provides length property for all range types. The
implementation of this property varies on different range types.
In the following example we find all cars whose price range's length is more
than 500.
::
session.query(Car).filter(
Car.price_range.length > 500
)
.. _intervals: https://github.com/kvesteri/intervals
"""
from collections.abc import Iterable
from datetime import timedelta
import sqlalchemy as sa
from sqlalchemy import types
from sqlalchemy.dialects.postgresql import (
DATERANGE,
INT4RANGE,
INT8RANGE,
NUMRANGE,
TSRANGE
)
from ..exceptions import ImproperlyConfigured
from .scalar_coercible import ScalarCoercible
intervals = None
try:
import intervals
except ImportError:
pass
class RangeComparator(types.TypeEngine.Comparator):
    """Comparator that coerces raw operands into the column's interval class."""

    @classmethod
    def coerced_func(cls, func):
        # Factory producing an operator wrapper that coerces the right-hand
        # operand before delegating to the base comparator implementation.
        # NOTE: dispatches explicitly to types.TypeEngine.Comparator, not
        # super(), so subclasses do not alter the comparison itself.
        def operation(self, other, **kwargs):
            other = self.coerce_arg(other)
            return getattr(types.TypeEngine.Comparator, func)(
                self, other, **kwargs
            )
        return operation

    def coerce_arg(self, other):
        # Raw scalars, tuples, lists and strings are wrapped into the
        # interval class; anything else (e.g. column expressions) passes
        # through unchanged.
        coerced_types = (
            self.type.interval_class.type,
            tuple,
            list,
            str,
        )

        if isinstance(other, coerced_types):
            return self.type.interval_class(other)

        return other

    def in_(self, other):
        # Coerce each element of a non-string iterable of candidate ranges.
        if (
            isinstance(other, Iterable) and
            not isinstance(other, str)
        ):
            other = map(self.coerce_arg, other)

        return super().in_(other)

    def notin_(self, other):
        # Mirror of in_ with negated membership.
        if (
            isinstance(other, Iterable) and
            not isinstance(other, str)
        ):
            other = map(self.coerce_arg, other)

        return super().notin_(other)

    def __rshift__(self, other, **kwargs):
        """
        Returns whether or not given interval is strictly right of another
        interval.

        [a, b] >> [c, d]        True, if a > d
        """
        other = self.coerce_arg(other)
        return self.op('>>')(other)

    def __lshift__(self, other, **kwargs):
        """
        Returns whether or not given interval is strictly left of another
        interval.

        [a, b] << [c, d]        True, if b < c
        """
        other = self.coerce_arg(other)
        return self.op('<<')(other)

    def contains(self, other, **kwargs):
        # '@>' — this range contains the other range/value.
        other = self.coerce_arg(other)
        return self.op('@>')(other)

    def contained_by(self, other, **kwargs):
        # '<@' — this range is contained by the other range.
        other = self.coerce_arg(other)
        return self.op('<@')(other)
class DiscreteRangeComparator(RangeComparator):
    """Comparator for ranges over discrete domains stepping by ``step``."""

    @property
    def length(self):
        # Subtract one step so the length reflects the inclusive interval
        # (discrete ranges are stored canonically with an exclusive upper).
        return sa.func.upper(self.expr) - self.step - sa.func.lower(self.expr)
class IntRangeComparator(DiscreteRangeComparator):
    """Discrete comparator for integer ranges (unit step)."""
    step = 1
class DateRangeComparator(DiscreteRangeComparator):
    """Discrete comparator for date ranges (one-day step)."""
    step = timedelta(days=1)
class ContinuousRangeComparator(RangeComparator):
    """Comparator for continuous ranges; length is a plain difference."""

    @property
    def length(self):
        return sa.func.upper(self.expr) - sa.func.lower(self.expr)
# Comparison dunders whose right-hand operand must be coerced to the
# column's interval class before comparing.
funcs = [
    '__eq__',
    '__ne__',
    '__lt__',
    '__le__',
    '__gt__',
    '__ge__',
]


# Install a coercing wrapper for each comparison operator on the base
# comparator; subclasses inherit the patched behavior.
for func in funcs:
    setattr(
        RangeComparator,
        func,
        RangeComparator.coerced_func(func)
    )
class RangeType(ScalarCoercible, types.TypeDecorator):
    """Base class for interval/range column types backed by the intervals pkg."""

    comparator_factory = RangeComparator

    def __init__(self, *args, **kwargs):
        if intervals is None:
            raise ImproperlyConfigured(
                'RangeType needs intervals package installed.'
            )
        # Optional discretization step forwarded to the interval class.
        self.step = kwargs.pop('step', None)
        super().__init__(*args, **kwargs)

    def load_dialect_impl(self, dialect):
        if dialect.name == 'postgresql':
            # Use the native range type for postgres.
            return dialect.type_descriptor(self.impl)
        else:
            # Other drivers don't have native types.
            return dialect.type_descriptor(sa.String(255))

    def process_bind_param(self, value, dialect):
        # Intervals serialize to their string representation.
        if value is not None:
            return str(value)
        return value

    def process_result_value(self, value, dialect):
        # Non-postgresql backends return strings; native ranges come back
        # as driver objects the interval class can consume directly.
        if isinstance(value, str):
            factory_func = self.interval_class.from_string
        else:
            factory_func = self.interval_class
        if value is not None:
            # NOTE(review): the guard checks interval_class.step but passes
            # self.step to the factory — confirm this asymmetry is intended.
            if self.interval_class.step is not None:
                return self.canonicalize_result_value(
                    factory_func(value, step=self.step)
                )
            else:
                return factory_func(value, step=self.step)
        return value

    def canonicalize_result_value(self, value):
        # Normalize to closed-closed form for discrete interval classes.
        return intervals.canonicalize(value, True, True)

    def _coerce(self, value):
        if value is None:
            return None
        return self.interval_class(value, step=self.step)
class IntRangeType(RangeType):
    """
    IntRangeType provides way for saving ranges of integers into database. On
    PostgreSQL this type maps to native INT4RANGE type while on other drivers
    this maps to simple string column.

    Example::

        from sqlalchemy_utils import IntRangeType

        class Event(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, autoincrement=True)
            name = sa.Column(sa.Unicode(255))
            estimated_number_of_persons = sa.Column(IntRangeType)

        party = Event(name='party')

        # we estimate the party to contain minimum of 10 persons and at max
        # 100 persons
        party.estimated_number_of_persons = [10, 100]

        print(party.estimated_number_of_persons)
        # '10-100'

    IntRangeType returns the values as IntInterval objects. These objects
    support many arithmetic operators::

        meeting = Event(name='meeting')
        meeting.estimated_number_of_persons = [20, 40]

        total = (
            meeting.estimated_number_of_persons +
            party.estimated_number_of_persons
        )
        print(total)
        # '30-140'
    """
    impl = INT4RANGE
    comparator_factory = IntRangeComparator
    cache_ok = True
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.interval_class = intervals.IntInterval
class Int8RangeType(RangeType):
    """
    Int8RangeType provides way for saving ranges of 8-byte integers into
    database. On PostgreSQL this type maps to native INT8RANGE type while on
    other drivers this maps to simple string column.

    Example::

        from sqlalchemy_utils import Int8RangeType

        class Event(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, autoincrement=True)
            name = sa.Column(sa.Unicode(255))
            estimated_number_of_persons = sa.Column(Int8RangeType)

        party = Event(name='party')

        # we estimate the party to contain minimum of 10 persons and at max
        # 100 persons
        party.estimated_number_of_persons = [10, 100]

        print(party.estimated_number_of_persons)
        # '10-100'

    Int8RangeType returns the values as IntInterval objects. These objects
    support many arithmetic operators::

        meeting = Event(name='meeting')
        meeting.estimated_number_of_persons = [20, 40]

        total = (
            meeting.estimated_number_of_persons +
            party.estimated_number_of_persons
        )
        print(total)
        # '30-140'
    """
    impl = INT8RANGE
    comparator_factory = IntRangeComparator
    cache_ok = True
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.interval_class = intervals.IntInterval
class DateRangeType(RangeType):
    """
    DateRangeType provides way for saving ranges of dates into database. On
    PostgreSQL this type maps to native DATERANGE type while on other drivers
    this maps to simple string column.

    Example::

        from sqlalchemy_utils import DateRangeType

        class Reservation(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, autoincrement=True)
            room_id = sa.Column(sa.Integer)
            during = sa.Column(DateRangeType)
    """
    impl = DATERANGE
    comparator_factory = DateRangeComparator
    cache_ok = True
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.interval_class = intervals.DateInterval
class NumericRangeType(RangeType):
    """
    NumericRangeType provides way for saving ranges of decimals into database.
    On PostgreSQL this type maps to native NUMRANGE type while on other drivers
    this maps to simple string column.

    Example::

        from sqlalchemy_utils import NumericRangeType

        class Car(Base):
            __tablename__ = 'car'
            id = sa.Column(sa.Integer, autoincrement=True)
            name = sa.Column(sa.Unicode(255))
            price_range = sa.Column(NumericRangeType)
    """
    impl = NUMRANGE
    comparator_factory = ContinuousRangeComparator
    cache_ok = True
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.interval_class = intervals.DecimalInterval
class DateTimeRangeType(RangeType):
    """
    DateTimeRangeType provides way for saving ranges of datetimes into
    database. On PostgreSQL this type maps to native TSRANGE type while on
    other drivers this maps to simple string column.
    """
    impl = TSRANGE
    comparator_factory = ContinuousRangeComparator
    cache_ok = True
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.interval_class = intervals.DateTimeInterval

View file

@ -0,0 +1,6 @@
class ScalarCoercible:
    """Mixin for types that coerce assigned scalar values.

    Subclasses implement :meth:`_coerce`; :meth:`coercion_listener` can
    then be attached to SQLAlchemy's ``set`` attribute event so that raw
    scalars are converted to the rich type on assignment.
    """

    def _coerce(self, value):
        # Subclasses define how a raw scalar becomes the rich value.
        raise NotImplementedError

    def coercion_listener(self, target, value, oldvalue, initiator):
        # Attribute-event hook: the returned value replaces the assigned one.
        return self._coerce(value)

View file

@ -0,0 +1,97 @@
import sqlalchemy as sa
from sqlalchemy import types
class ScalarListException(Exception):
    """Raised when a list cannot be serialized, e.g. when an item
    contains the configured separator string."""
    pass
class ScalarListType(types.TypeDecorator):
    """
    ScalarListType type provides convenient way for saving multiple scalar
    values in one column. ScalarListType works like list on python side and
    saves the result as comma-separated list in the database (custom separators
    can also be used).

    Example ::

        from sqlalchemy_utils import ScalarListType

        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, autoincrement=True)
            hobbies = sa.Column(ScalarListType())

        user = User()
        user.hobbies = ['football', 'ice_hockey']
        session.commit()

    You can easily set up integer lists too:

    ::

        from sqlalchemy_utils import ScalarListType

        class Player(Base):
            __tablename__ = 'player'
            id = sa.Column(sa.Integer, autoincrement=True)
            points = sa.Column(ScalarListType(int))

        player = Player()
        player.points = [11, 12, 8, 80]
        session.commit()

    ScalarListType is always stored as text. To use an array field on
    PostgreSQL database use variant construct::

        from sqlalchemy_utils import ScalarListType

        class Player(Base):
            __tablename__ = 'player'
            id = sa.Column(sa.Integer, autoincrement=True)
            points = sa.Column(
                ARRAY(Integer).with_variant(ScalarListType(int), 'sqlite')
            )

    :param coerce_func:
        callable applied to every item when values are loaded back from the
        database (defaults to ``str``).
    :param separator:
        string used to separate items in storage (defaults to ``,``).
    """
    impl = sa.UnicodeText()
    cache_ok = True

    def __init__(self, coerce_func=str, separator=','):
        self.separator = str(separator)
        self.coerce_func = coerce_func
        # Bugfix: TypeDecorator.__init__ must run so that the underlying
        # impl instance is constructed properly; it was previously skipped.
        super().__init__()

    def process_bind_param(self, value, dialect):
        # Convert list of values to unicode separator-separated list
        # Example: [1, 2, 3, 4] -> '1, 2, 3, 4'
        if value is not None:
            if any(self.separator in str(item) for item in value):
                raise ScalarListException(
                    "List values can't contain string '%s' (its being used as "
                    "separator. If you wish for scalar list values to contain "
                    "these strings, use a different separator string.)"
                    % self.separator
                )
            return self.separator.join(
                map(str, value)
            )

    def process_result_value(self, value, dialect):
        if value is not None:
            if value == '':
                # An empty string round-trips to an empty list, not [''].
                return []
            # coerce each value
            return list(map(
                self.coerce_func, value.split(self.separator)
            ))

View file

@ -0,0 +1,102 @@
from sqlalchemy import types
from ..exceptions import ImproperlyConfigured
from .scalar_coercible import ScalarCoercible
class TimezoneType(ScalarCoercible, types.TypeDecorator):
    """
    TimezoneType provides a way for saving timezones objects into database.
    TimezoneType saves timezone objects as strings on the way in and converts
    them back to objects when querying the database.
    ::
        from sqlalchemy_utils import TimezoneType
        class User(Base):
            __tablename__ = 'user'
            # Pass backend='pytz' to change it to use pytz. Other values:
            # 'dateutil' (default), and 'zoneinfo'.
            timezone = sa.Column(TimezoneType(backend='pytz'))
    :param backend: Whether to use 'dateutil', 'pytz' or 'zoneinfo' for
        timezones. 'zoneinfo' uses the standard library module in Python 3.9+,
        but requires the external 'backports.zoneinfo' package for older
        Python versions.
    """
    impl = types.Unicode(50)
    # Overwritten per-instance in __init__ with the backend's tz class.
    python_type = None
    cache_ok = True
    def __init__(self, backend='dateutil'):
        # Each branch sets up:
        #   self._to   -- name string -> timezone object
        #   self._from -- timezone object -> name string
        self.backend = backend
        if backend == 'dateutil':
            try:
                from dateutil.tz import tzfile
                from dateutil.zoneinfo import get_zonefile_instance
                self.python_type = tzfile
                self._to = get_zonefile_instance().zones.get
                # dateutil tzfile objects remember their source path.
                self._from = lambda x: str(x._filename)
            except ImportError:
                raise ImproperlyConfigured(
                    "'python-dateutil' is required to use the "
                    "'dateutil' backend for 'TimezoneType'"
                )
        elif backend == 'pytz':
            try:
                from pytz import timezone
                from pytz.tzinfo import BaseTzInfo
                self.python_type = BaseTzInfo
                self._to = timezone
                self._from = str
            except ImportError:
                raise ImproperlyConfigured(
                    "'pytz' is required to use the 'pytz' backend "
                    "for 'TimezoneType'"
                )
        elif backend == "zoneinfo":
            # Prefer the stdlib module (3.9+), fall back to the backport.
            try:
                import zoneinfo
            except ImportError:
                try:
                    from backports import zoneinfo
                except ImportError:
                    raise ImproperlyConfigured(
                        "'backports.zoneinfo' is required to use "
                        "the 'zoneinfo' backend for 'TimezoneType'"
                        "on Python version < 3.9"
                    )
            self.python_type = zoneinfo.ZoneInfo
            self._to = zoneinfo.ZoneInfo
            self._from = str
        else:
            raise ImproperlyConfigured(
                "'pytz', 'dateutil' or 'zoneinfo' are the backends "
                "supported for 'TimezoneType'"
            )
    def _coerce(self, value):
        # Convert a timezone-name string to a timezone object; pass
        # through values already of the backend's type (or None).
        if value is not None and not isinstance(value, self.python_type):
            obj = self._to(value)
            if obj is None:
                raise ValueError("unknown time zone '%s'" % value)
            return obj
        return value
    def process_bind_param(self, value, dialect):
        return self._from(self._coerce(value)) if value else None
    def process_result_value(self, value, dialect):
        return self._to(value) if value else None

View file

@ -0,0 +1,108 @@
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import TSVECTOR
class TSVectorType(sa.types.TypeDecorator):
    """
    .. note::
        This type is PostgreSQL specific and is not supported by other
        dialects.

    Provides additional functionality for SQLAlchemy PostgreSQL dialect's
    TSVECTOR_ type. This additional functionality includes:

    * Vector concatenation
    * regconfig constructor parameter which is applied to match function if no
      postgresql_regconfig parameter is given
    * Provides extensible base for extensions such as SQLAlchemy-Searchable_

    .. _TSVECTOR:
        https://docs.sqlalchemy.org/en/latest/dialects/postgresql.html#full-text-search
    .. _SQLAlchemy-Searchable:
        https://www.github.com/kvesteri/sqlalchemy-searchable

    ::

        from sqlalchemy_utils import TSVectorType

        class Article(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.String(100))
            search_vector = sa.Column(TSVectorType)

        # Find all articles whose name matches 'finland'
        session.query(Article).filter(Article.search_vector.match('finland'))

    TSVectorType also supports vector concatenation.

    ::

        class Article(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.String(100))
            name_vector = sa.Column(TSVectorType)
            content = sa.Column(sa.String)
            content_vector = sa.Column(TSVectorType)

        # Find all articles whose name or content matches 'finland'
        session.query(Article).filter(
            (Article.name_vector | Article.content_vector).match('finland')
        )

    You can configure TSVectorType to use a specific regconfig.

    ::

        class Article(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.String(100))
            search_vector = sa.Column(
                TSVectorType(regconfig='pg_catalog.simple')
            )

    Now expression such as::

        Article.search_vector.match('finland')

    Would be equivalent to SQL::

        search_vector @@ to_tsquery('pg_catalog.simple', 'finland')
    """
    impl = TSVECTOR
    cache_ok = True
    class comparator_factory(TSVECTOR.Comparator):
        def match(self, other, **kwargs):
            # Fall back to the regconfig given at construction time unless
            # the caller passed postgresql_regconfig explicitly.
            if 'postgresql_regconfig' not in kwargs:
                if 'regconfig' in self.type.options:
                    kwargs['postgresql_regconfig'] = (
                        self.type.options['regconfig']
                    )
            return TSVECTOR.Comparator.match(self, other, **kwargs)
        def __or__(self, other):
            # tsvector concatenation operator.
            return self.op('||')(other)
    def __init__(self, *args, **kwargs):
        """
        Initializes new TSVectorType

        :param *args: list of column names
        :param **kwargs: various other options for this TSVectorType
        """
        self.columns = args
        self.options = kwargs
        super().__init__()

View file

@ -0,0 +1,68 @@
from sqlalchemy import types
from .scalar_coercible import ScalarCoercible
# ``furl`` is an optional dependency: when it is not installed URLType
# degrades to storing and returning plain strings.
try:
    from furl import furl
except ImportError:
    furl = None
class URLType(ScalarCoercible, types.TypeDecorator):
    """
    URLType stores furl_ objects into database.

    .. _furl: https://github.com/gruns/furl

    ::

        from sqlalchemy_utils import URLType
        from furl import furl

        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            website = sa.Column(URLType)

        user = User(website='www.example.com')

        # website is coerced to furl object, hence all nice furl operations
        # come available
        user.website.args['some_argument'] = '12'
        print(user.website)
        # www.example.com?some_argument=12
    """
    impl = types.UnicodeText
    cache_ok = True

    def process_bind_param(self, value, dialect):
        # furl instances are serialized to their string form; plain
        # strings pass through; anything else stores NULL.
        if furl is not None and isinstance(value, furl):
            return str(value)
        return value if isinstance(value, str) else None

    def process_result_value(self, value, dialect):
        # Without furl installed (or for NULLs) the raw value is returned.
        if furl is None or value is None:
            return value
        return furl(value)

    def _coerce(self, value):
        # Wrap assigned strings in a furl object when possible.
        if furl is None or value is None or isinstance(value, furl):
            return value
        return furl(value)

    @property
    def python_type(self):
        return self.impl.type.python_type

View file

@ -0,0 +1,113 @@
import uuid
from sqlalchemy import types, util
from sqlalchemy.dialects import mssql, postgresql
from ..compat import get_sqlalchemy_version
from .scalar_coercible import ScalarCoercible
# Resolved once at import time; used below to choose the appropriate
# process_literal_param implementation.
sqlalchemy_version = get_sqlalchemy_version()
class UUIDType(ScalarCoercible, types.TypeDecorator):
    """
    Stores a UUID in the database natively when it can and falls back to
    a BINARY(16) or a CHAR(32) when it can't.
    ::
        from sqlalchemy_utils import UUIDType
        import uuid
        class User(Base):
            __tablename__ = 'user'
            # Pass `binary=False` to fallback to CHAR instead of BINARY
            id = sa.Column(
                UUIDType(binary=False),
                primary_key=True,
                default=uuid.uuid4
            )
    """
    impl = types.BINARY(16)
    python_type = uuid.UUID
    cache_ok = True
    def __init__(self, binary=True, native=True):
        """
        :param binary: Whether to use a BINARY(16) or CHAR(32) fallback.
        :param native: Whether to use the database's native UUID type when
            available (PostgreSQL, CockroachDB, MSSQL).
        """
        self.binary = binary
        self.native = native
    def __repr__(self):
        return util.generic_repr(self)
    def load_dialect_impl(self, dialect):
        if self.native and dialect.name in ('postgresql', 'cockroachdb'):
            # Use the native UUID type.
            return dialect.type_descriptor(postgresql.UUID())
        if dialect.name == 'mssql' and self.native:
            # Use the native UNIQUEIDENTIFIER type.
            return dialect.type_descriptor(mssql.UNIQUEIDENTIFIER())
        else:
            # Fallback to either a BINARY or a CHAR.
            kind = self.impl if self.binary else types.CHAR(32)
            return dialect.type_descriptor(kind)
    @staticmethod
    def _coerce(value):
        # Accept hex strings first, then raw 16-byte values.
        if value and not isinstance(value, uuid.UUID):
            try:
                value = uuid.UUID(value)
            except (TypeError, ValueError):
                value = uuid.UUID(bytes=value)
        return value
    # sqlalchemy >= 1.4.30 quotes UUID's automatically.
    # It is only necessary to quote UUID's in sqlalchemy < 1.4.30.
    if sqlalchemy_version < (1, 4, 30):
        def process_literal_param(self, value, dialect):
            return f"'{value}'" if value else value
    else:
        def process_literal_param(self, value, dialect):
            return value
    def process_bind_param(self, value, dialect):
        if value is None:
            return value
        if not isinstance(value, uuid.UUID):
            value = self._coerce(value)
        if self.native and dialect.name in (
            'postgresql',
            'mssql',
            'cockroachdb'
        ):
            # Native UUID columns accept the canonical string form.
            return str(value)
        return value.bytes if self.binary else value.hex
    def process_result_value(self, value, dialect):
        if value is None:
            return value
        if self.native and dialect.name in (
            'postgresql',
            'mssql',
            'cockroachdb'
        ):
            if isinstance(value, uuid.UUID):
                # Some drivers convert PostgreSQL's uuid values to
                # Python's uuid.UUID objects by themselves
                return value
            return uuid.UUID(value)
        return uuid.UUID(bytes=value) if self.binary else uuid.UUID(value)

View file

@ -0,0 +1,82 @@
from sqlalchemy import types
from .. import i18n
from ..exceptions import ImproperlyConfigured
from ..primitives import WeekDay, WeekDays
from .bit import BitType
from .scalar_coercible import ScalarCoercible
class WeekDaysType(types.TypeDecorator, ScalarCoercible):
    """
    WeekDaysType offers way of saving WeekDays objects into database. The
    WeekDays objects are converted to bit strings on the way in and back to
    WeekDays objects on the way out.

    In order to use WeekDaysType you need to install Babel_ first.

    .. _Babel: https://babel.pocoo.org/

    ::

        from sqlalchemy_utils import WeekDaysType, WeekDays
        from babel import Locale

        class Schedule(Base):
            __tablename__ = 'schedule'
            id = sa.Column(sa.Integer, autoincrement=True)
            working_days = sa.Column(WeekDaysType)

        schedule = Schedule()
        schedule.working_days = WeekDays('0001111')
        session.add(schedule)
        session.commit()

        print(schedule.working_days)  # Thursday, Friday, Saturday, Sunday

    WeekDaysType also supports scalar coercion:

    ::

        schedule.working_days = '1110000'
        schedule.working_days  # WeekDays object
    """
    impl = BitType(WeekDay.NUM_WEEK_DAYS)
    cache_ok = True
    def __init__(self, *args, **kwargs):
        if i18n.babel is None:
            raise ImproperlyConfigured(
                "'babel' package is required to use 'WeekDaysType'"
            )
        super().__init__(*args, **kwargs)
    @property
    def comparator_factory(self):
        # Delegate comparisons to the underlying bit type.
        return self.impl.comparator_factory
    def process_bind_param(self, value, dialect):
        if isinstance(value, WeekDays):
            value = value.as_bit_string()
        if dialect.name == 'mysql':
            # MySQL bit columns expect bytes rather than str.
            return bytes(value, 'utf8')
        return value
    def process_result_value(self, value, dialect):
        if value is not None:
            return WeekDays(value)
    def _coerce(self, value):
        # ScalarCoercible hook: accept bit strings on assignment.
        if value is not None and not isinstance(value, WeekDays):
            return WeekDays(value)
        return value

View file

@ -0,0 +1,22 @@
from collections.abc import Iterable
def str_coercible(cls):
    """Class decorator adding a ``__str__`` that delegates to ``__unicode__``."""
    cls.__str__ = lambda self: self.__unicode__()
    return cls
def is_sequence(value):
    """Return ``True`` for iterables that are not plain strings."""
    if isinstance(value, str):
        return False
    return isinstance(value, Iterable)
def starts_with(iterable, prefix):
    """
    Returns whether or not given iterable starts with given prefix.

    Only the first ``len(prefix)`` items of ``iterable`` are consumed, so
    arbitrarily long iterables are handled without materializing them.
    """
    from itertools import islice
    prefix = list(prefix)
    return list(islice(iterable, len(prefix))) == prefix

View file

@ -0,0 +1,210 @@
import sqlalchemy as sa
from sqlalchemy.ext import compiler
from sqlalchemy.schema import DDLElement, PrimaryKeyConstraint
from sqlalchemy.sql.expression import ClauseElement, Executable
from sqlalchemy_utils.functions import get_columns
class CreateView(DDLElement):
    """DDL element emitting ``CREATE [MATERIALIZED] VIEW ... AS ...``."""
    def __init__(self, name, selectable, materialized=False):
        self.name = name
        self.selectable = selectable
        self.materialized = materialized
@compiler.compiles(CreateView)
def compile_create_materialized_view(element, compiler, **kw):
    # The selectable is rendered with literal binds because DDL statements
    # cannot carry bound parameters.
    return 'CREATE {}VIEW {} AS {}'.format(
        'MATERIALIZED ' if element.materialized else '',
        compiler.dialect.identifier_preparer.quote(element.name),
        compiler.sql_compiler.process(element.selectable, literal_binds=True),
    )
class DropView(DDLElement):
    """DDL element emitting ``DROP [MATERIALIZED] VIEW IF EXISTS ...``."""
    def __init__(self, name, materialized=False, cascade=True):
        self.name = name
        self.materialized = materialized
        self.cascade = cascade
@compiler.compiles(DropView)
def compile_drop_materialized_view(element, compiler, **kw):
    # Renders e.g. "DROP MATERIALIZED VIEW IF EXISTS name CASCADE".
    return 'DROP {}VIEW IF EXISTS {} {}'.format(
        'MATERIALIZED ' if element.materialized else '',
        compiler.dialect.identifier_preparer.quote(element.name),
        'CASCADE' if element.cascade else ''
    )
def create_table_from_selectable(
    name,
    selectable,
    indexes=None,
    metadata=None,
    aliases=None,
    **kwargs
):
    """Build a ``sa.Table`` mirroring the columns of *selectable*.

    :param name: name of the table to build.
    :param selectable: SQLAlchemy selectable whose columns are mirrored.
    :param indexes: optional list of Index instances appended as table args.
    :param metadata: MetaData to attach the table to (a fresh one if None).
    :param aliases: optional mapping of column name -> attribute key.

    If the selectable exposes no primary key columns, a PrimaryKeyConstraint
    spanning all columns is appended so the table always has one.
    """
    indexes = [] if indexes is None else indexes
    metadata = sa.MetaData() if metadata is None else metadata
    aliases = {} if aliases is None else aliases
    columns = list(get_columns(selectable))
    table_args = [
        sa.Column(
            col.name,
            col.type,
            key=aliases.get(col.name, col.name),
            primary_key=col.primary_key
        )
        for col in columns
    ]
    table_args.extend(indexes)
    table = sa.Table(name, metadata, *table_args, **kwargs)
    if not any(col.primary_key for col in columns):
        table.append_constraint(
            PrimaryKeyConstraint(*[col.name for col in columns])
        )
    return table
def create_materialized_view(
    name,
    selectable,
    metadata,
    indexes=None,
    aliases=None
):
    """ Create a view on a given metadata
    :param name: The name of the view to create.
    :param selectable: An SQLAlchemy selectable e.g. a select() statement.
    :param metadata:
        An SQLAlchemy Metadata instance that stores the features of the
        database being described.
    :param indexes: An optional list of SQLAlchemy Index instances.
    :param aliases:
        An optional dictionary containing with keys as column names and values
        as column aliases.
    Same as for ``create_view`` except that a ``CREATE MATERIALIZED VIEW``
    statement is emitted instead of a ``CREATE VIEW``.
    """
    table = create_table_from_selectable(
        name=name,
        selectable=selectable,
        indexes=indexes,
        # Deliberately a fresh MetaData: the table is only a column-level
        # description of the view; attaching it to the caller's metadata
        # would make create_all() also emit a CREATE TABLE for it.
        metadata=None,
        aliases=aliases
    )
    sa.event.listen(
        metadata,
        'after_create',
        CreateView(name, selectable, materialized=True)
    )
    @sa.event.listens_for(metadata, 'after_create')
    def create_indexes(target, connection, **kw):
        # Indexes are created after the view itself exists.
        for idx in table.indexes:
            idx.create(connection)
    sa.event.listen(
        metadata,
        'before_drop',
        DropView(name, materialized=True)
    )
    return table
def create_view(
    name,
    selectable,
    metadata,
    cascade_on_drop=True
):
    """ Create a view on a given metadata
    :param name: The name of the view to create.
    :param selectable: An SQLAlchemy selectable e.g. a select() statement.
    :param metadata:
        An SQLAlchemy Metadata instance that stores the features of the
        database being described.
    :param cascade_on_drop:
        Whether to append ``CASCADE`` to the ``DROP VIEW`` statement emitted
        when the metadata is dropped.
    The process for creating a view is similar to the standard way that a
    table is constructed, except that a selectable is provided instead of
    a set of columns. The view is created once a ``CREATE`` statement is
    executed against the supplied metadata (e.g. ``metadata.create_all(..)``),
    and dropped when a ``DROP`` is executed against the metadata.
    To create a view that performs basic filtering on a table. ::
        metadata = MetaData()
        users = Table('users', metadata,
                Column('id', Integer, primary_key=True),
                Column('name', String),
                Column('fullname', String),
                Column('premium_user', Boolean, default=False),
            )
        premium_members = select(users).where(users.c.premium_user == True)
        # sqlalchemy 1.3:
        # premium_members = select([users]).where(users.c.premium_user == True)
        create_view('premium_users', premium_members, metadata)
        metadata.create_all(engine) # View is created at this point
    """
    table = create_table_from_selectable(
        name=name,
        selectable=selectable,
        # Deliberately a fresh MetaData: the table only describes the view's
        # columns; attaching it to the caller's metadata would make
        # create_all() also emit a CREATE TABLE for it.
        metadata=None
    )
    sa.event.listen(metadata, 'after_create', CreateView(name, selectable))
    @sa.event.listens_for(metadata, 'after_create')
    def create_indexes(target, connection, **kw):
        # Indexes are created after the view itself exists.
        for idx in table.indexes:
            idx.create(connection)
    sa.event.listen(
        metadata,
        'before_drop',
        DropView(name, cascade=cascade_on_drop)
    )
    return table
class RefreshMaterializedView(Executable, ClauseElement):
    """Executable clause emitting ``REFRESH MATERIALIZED VIEW``."""
    inherit_cache = True
    def __init__(self, name, concurrently):
        self.name = name
        self.concurrently = concurrently
@compiler.compiles(RefreshMaterializedView)
def compile_refresh_materialized_view(element, compiler):
    # NOTE(review): unlike the other compile hooks in this module this
    # signature takes no **kw -- confirm SQLAlchemy never passes extra
    # keywords when compiling executable clause elements.
    return 'REFRESH MATERIALIZED VIEW {concurrently}{name}'.format(
        concurrently='CONCURRENTLY ' if element.concurrently else '',
        name=compiler.dialect.identifier_preparer.quote(element.name),
    )
def refresh_materialized_view(session, name, concurrently=False):
    """ Refreshes an already existing materialized view
    :param session: An SQLAlchemy Session instance.
    :param name: The name of the materialized view to refresh.
    :param concurrently:
        Optional flag that causes the ``CONCURRENTLY`` parameter
        to be specified when the materialized view is refreshed.
        Note that PostgreSQL requires a unique index on the view for
        concurrent refreshes.
    """
    # Since session.execute() bypasses autoflush, we must manually flush in
    # order to include newly-created/modified objects in the refresh.
    session.flush()
    session.execute(RefreshMaterializedView(name, concurrently))