forked from Raiza.dev/EliteBot
Cleaned up the directories
This commit is contained in:
parent
f708506d68
commit
a683fcffea
1340 changed files with 554582 additions and 6840 deletions
|
@ -0,0 +1,104 @@
|
|||
from .aggregates import aggregated # noqa
|
||||
from .asserts import ( # noqa
|
||||
assert_max_length,
|
||||
assert_max_value,
|
||||
assert_min_value,
|
||||
assert_non_nullable,
|
||||
assert_nullable
|
||||
)
|
||||
from .exceptions import ImproperlyConfigured # noqa
|
||||
from .expressions import Asterisk, row_to_json # noqa
|
||||
from .functions import ( # noqa
|
||||
cast_if,
|
||||
create_database,
|
||||
create_mock_engine,
|
||||
database_exists,
|
||||
dependent_objects,
|
||||
drop_database,
|
||||
escape_like,
|
||||
get_bind,
|
||||
get_class_by_table,
|
||||
get_column_key,
|
||||
get_columns,
|
||||
get_declarative_base,
|
||||
get_fk_constraint_for_columns,
|
||||
get_hybrid_properties,
|
||||
get_mapper,
|
||||
get_primary_keys,
|
||||
get_referencing_foreign_keys,
|
||||
get_tables,
|
||||
get_type,
|
||||
group_foreign_keys,
|
||||
has_changes,
|
||||
has_index,
|
||||
has_unique_index,
|
||||
identity,
|
||||
is_loaded,
|
||||
json_sql,
|
||||
jsonb_sql,
|
||||
merge_references,
|
||||
mock_engine,
|
||||
naturally_equivalent,
|
||||
render_expression,
|
||||
render_statement,
|
||||
table_name
|
||||
)
|
||||
from .generic import generic_relationship # noqa
|
||||
from .i18n import TranslationHybrid # noqa
|
||||
from .listeners import ( # noqa
|
||||
auto_delete_orphans,
|
||||
coercion_listener,
|
||||
force_auto_coercion,
|
||||
force_instant_defaults
|
||||
)
|
||||
from .models import generic_repr, Timestamp # noqa
|
||||
from .observer import observes # noqa
|
||||
from .primitives import Country, Currency, Ltree, WeekDay, WeekDays # noqa
|
||||
from .proxy_dict import proxy_dict, ProxyDict # noqa
|
||||
from .query_chain import QueryChain # noqa
|
||||
from .types import ( # noqa
|
||||
ArrowType,
|
||||
Choice,
|
||||
ChoiceType,
|
||||
ColorType,
|
||||
CompositeType,
|
||||
CountryType,
|
||||
CurrencyType,
|
||||
DateRangeType,
|
||||
DateTimeRangeType,
|
||||
EmailType,
|
||||
EncryptedType,
|
||||
EnrichedDateTimeType,
|
||||
EnrichedDateType,
|
||||
instrumented_list,
|
||||
InstrumentedList,
|
||||
Int8RangeType,
|
||||
IntRangeType,
|
||||
IPAddressType,
|
||||
JSONType,
|
||||
LocaleType,
|
||||
LtreeType,
|
||||
NumericRangeType,
|
||||
Password,
|
||||
PasswordType,
|
||||
PhoneNumber,
|
||||
PhoneNumberParseException,
|
||||
PhoneNumberType,
|
||||
register_composites,
|
||||
remove_composite_listeners,
|
||||
ScalarListException,
|
||||
ScalarListType,
|
||||
StringEncryptedType,
|
||||
TimezoneType,
|
||||
TSVectorType,
|
||||
URLType,
|
||||
UUIDType,
|
||||
WeekDaysType
|
||||
)
|
||||
from .view import ( # noqa
|
||||
create_materialized_view,
|
||||
create_view,
|
||||
refresh_materialized_view
|
||||
)
|
||||
|
||||
__version__ = '0.41.1'
|
|
@ -0,0 +1,576 @@
|
|||
"""
|
||||
SQLAlchemy-Utils provides a way of automatically calculating aggregate values of
|
||||
related models and saving them to parent model.
|
||||
|
||||
This solution is inspired by RoR counter cache,
|
||||
`counter_culture`_ and `stackoverflow reply by Michael Bayer`_.
|
||||
|
||||
Why?
|
||||
----
|
||||
|
||||
Many times you may have situations where you need to calculate dynamically some
|
||||
aggregate value for given model. Some simple examples include:
|
||||
|
||||
- Number of products in a catalog
|
||||
- Average rating for movie
|
||||
- Latest forum post
|
||||
- Total price of orders for given customer
|
||||
|
||||
Now all these aggregates can be elegantly implemented with SQLAlchemy
|
||||
column_property_ function. However when your data grows calculating these
|
||||
values on the fly might start to hurt the performance of your application. The
|
||||
more aggregates you are using the more performance penalty you get.
|
||||
|
||||
This module provides a way of calculating these values automatically and
|
||||
efficiently at the time of modification rather than on the fly.
|
||||
|
||||
|
||||
Features
|
||||
--------
|
||||
|
||||
* Automatically updates aggregate columns when aggregated values change
|
||||
* Supports aggregate values through arbitrary number levels of relations
|
||||
* Highly optimized: uses single query per transaction per aggregate column
|
||||
* Aggregated columns can be of any data type and use any selectable scalar
|
||||
expression
|
||||
|
||||
|
||||
.. _column_property:
|
||||
https://docs.sqlalchemy.org/en/latest/orm/mapped_sql_expr.html#using-column-property
|
||||
.. _counter_culture: https://github.com/magnusvk/counter_culture
|
||||
.. _stackoverflow reply by Michael Bayer:
|
||||
https://stackoverflow.com/a/13765857/520932
|
||||
|
||||
|
||||
Simple aggregates
|
||||
-----------------
|
||||
|
||||
::
|
||||
|
||||
from sqlalchemy_utils import aggregated
|
||||
|
||||
|
||||
class Thread(Base):
|
||||
__tablename__ = 'thread'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
@aggregated('comments', sa.Column(sa.Integer))
|
||||
def comment_count(self):
|
||||
return sa.func.count('1')
|
||||
|
||||
comments = sa.orm.relationship(
|
||||
'Comment',
|
||||
backref='thread'
|
||||
)
|
||||
|
||||
|
||||
class Comment(Base):
|
||||
__tablename__ = 'comment'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
content = sa.Column(sa.UnicodeText)
|
||||
thread_id = sa.Column(sa.Integer, sa.ForeignKey(Thread.id))
|
||||
|
||||
|
||||
thread = Thread(name='SQLAlchemy development')
|
||||
thread.comments.append(Comment('Going good!'))
|
||||
thread.comments.append(Comment('Great new features!'))
|
||||
|
||||
session.add(thread)
|
||||
session.commit()
|
||||
|
||||
thread.comment_count # 2
|
||||
|
||||
|
||||
|
||||
Custom aggregate expressions
|
||||
----------------------------
|
||||
|
||||
Aggregate expression can be virtually any SQL expression not just a simple
|
||||
function taking one parameter. You can try things such as subqueries and
|
||||
different kinds of functions.
|
||||
|
||||
In the following example we have a Catalog of products where each catalog
|
||||
knows the net worth of its products.
|
||||
|
||||
::
|
||||
|
||||
|
||||
from sqlalchemy_utils import aggregated
|
||||
|
||||
|
||||
class Catalog(Base):
|
||||
__tablename__ = 'catalog'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
@aggregated('products', sa.Column(sa.Integer))
|
||||
def net_worth(self):
|
||||
return sa.func.sum(Product.price)
|
||||
|
||||
products = sa.orm.relationship('Product')
|
||||
|
||||
|
||||
class Product(Base):
|
||||
__tablename__ = 'product'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
price = sa.Column(sa.Numeric)
|
||||
|
||||
catalog_id = sa.Column(sa.Integer, sa.ForeignKey(Catalog.id))
|
||||
|
||||
|
||||
Now the net_worth column of Catalog model will be automatically whenever:
|
||||
|
||||
* A new product is added to the catalog
|
||||
* A product is deleted from the catalog
|
||||
* The price of catalog product is changed
|
||||
|
||||
|
||||
::
|
||||
|
||||
|
||||
from decimal import Decimal
|
||||
|
||||
|
||||
product1 = Product(name='Some product', price=Decimal(1000))
|
||||
product2 = Product(name='Some other product', price=Decimal(500))
|
||||
|
||||
|
||||
catalog = Catalog(
|
||||
name='My first catalog',
|
||||
products=[
|
||||
product1,
|
||||
product2
|
||||
]
|
||||
)
|
||||
session.add(catalog)
|
||||
session.commit()
|
||||
|
||||
session.refresh(catalog)
|
||||
catalog.net_worth # 1500
|
||||
|
||||
session.delete(product2)
|
||||
session.commit()
|
||||
session.refresh(catalog)
|
||||
|
||||
catalog.net_worth # 1000
|
||||
|
||||
product1.price = 2000
|
||||
session.commit()
|
||||
session.refresh(catalog)
|
||||
|
||||
catalog.net_worth # 2000
|
||||
|
||||
|
||||
|
||||
|
||||
Multiple aggregates per class
|
||||
-----------------------------
|
||||
|
||||
Sometimes you may need to define multiple aggregate values for same class. If
|
||||
you need to define lots of relationships pointing to same class, remember to
|
||||
define the relationships as viewonly when possible.
|
||||
|
||||
|
||||
::
|
||||
|
||||
|
||||
from sqlalchemy_utils import aggregated
|
||||
|
||||
|
||||
class Customer(Base):
|
||||
__tablename__ = 'customer'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
@aggregated('orders', sa.Column(sa.Integer))
|
||||
def orders_sum(self):
|
||||
return sa.func.sum(Order.price)
|
||||
|
||||
@aggregated('invoiced_orders', sa.Column(sa.Integer))
|
||||
def invoiced_orders_sum(self):
|
||||
return sa.func.sum(Order.price)
|
||||
|
||||
orders = sa.orm.relationship('Order')
|
||||
|
||||
invoiced_orders = sa.orm.relationship(
|
||||
'Order',
|
||||
primaryjoin=
|
||||
'sa.and_(Order.customer_id == Customer.id, Order.invoiced)',
|
||||
viewonly=True
|
||||
)
|
||||
|
||||
|
||||
class Order(Base):
|
||||
__tablename__ = 'order'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
price = sa.Column(sa.Numeric)
|
||||
invoiced = sa.Column(sa.Boolean, default=False)
|
||||
customer_id = sa.Column(sa.Integer, sa.ForeignKey(Customer.id))
|
||||
|
||||
|
||||
Many-to-Many aggregates
|
||||
-----------------------
|
||||
|
||||
Aggregate expressions also support many-to-many relationships. The usual use
|
||||
scenarios includes things such as:
|
||||
|
||||
1. Friend count of a user
|
||||
2. Group count where given user belongs to
|
||||
|
||||
::
|
||||
|
||||
|
||||
user_group = sa.Table('user_group', Base.metadata,
|
||||
sa.Column('user_id', sa.Integer, sa.ForeignKey('user.id')),
|
||||
sa.Column('group_id', sa.Integer, sa.ForeignKey('group.id'))
|
||||
)
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
@aggregated('groups', sa.Column(sa.Integer, default=0))
|
||||
def group_count(self):
|
||||
return sa.func.count('1')
|
||||
|
||||
groups = sa.orm.relationship(
|
||||
'Group',
|
||||
backref='users',
|
||||
secondary=user_group
|
||||
)
|
||||
|
||||
|
||||
class Group(Base):
|
||||
__tablename__ = 'group'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
|
||||
|
||||
user = User(name='John Matrix')
|
||||
user.groups = [Group(name='Group A'), Group(name='Group B')]
|
||||
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
||||
session.refresh(user)
|
||||
user.group_count # 2
|
||||
|
||||
|
||||
Multi-level aggregates
|
||||
----------------------
|
||||
|
||||
Aggregates can span across multiple relationships. In the following example
|
||||
each Catalog has a net_worth which is the sum of all products in all
|
||||
categories.
|
||||
|
||||
|
||||
::
|
||||
|
||||
|
||||
from sqlalchemy_utils import aggregated
|
||||
|
||||
|
||||
class Catalog(Base):
|
||||
__tablename__ = 'catalog'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
@aggregated('categories.products', sa.Column(sa.Integer))
|
||||
def net_worth(self):
|
||||
return sa.func.sum(Product.price)
|
||||
|
||||
categories = sa.orm.relationship('Category')
|
||||
|
||||
|
||||
class Category(Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
catalog_id = sa.Column(sa.Integer, sa.ForeignKey(Catalog.id))
|
||||
|
||||
products = sa.orm.relationship('Product')
|
||||
|
||||
|
||||
class Product(Base):
|
||||
__tablename__ = 'product'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
price = sa.Column(sa.Numeric)
|
||||
|
||||
category_id = sa.Column(sa.Integer, sa.ForeignKey(Category.id))
|
||||
|
||||
|
||||
Examples
|
||||
--------
|
||||
|
||||
Average movie rating
|
||||
^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
::
|
||||
|
||||
|
||||
from sqlalchemy_utils import aggregated
|
||||
|
||||
|
||||
class Movie(Base):
|
||||
__tablename__ = 'movie'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
|
||||
@aggregated('ratings', sa.Column(sa.Numeric))
|
||||
def avg_rating(self):
|
||||
return sa.func.avg(Rating.stars)
|
||||
|
||||
ratings = sa.orm.relationship('Rating')
|
||||
|
||||
|
||||
class Rating(Base):
|
||||
__tablename__ = 'rating'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
stars = sa.Column(sa.Integer)
|
||||
|
||||
movie_id = sa.Column(sa.Integer, sa.ForeignKey(Movie.id))
|
||||
|
||||
|
||||
movie = Movie('Terminator 2')
|
||||
movie.ratings.append(Rating(stars=5))
|
||||
movie.ratings.append(Rating(stars=4))
|
||||
movie.ratings.append(Rating(stars=3))
|
||||
session.add(movie)
|
||||
session.commit()
|
||||
|
||||
movie.avg_rating # 4
|
||||
|
||||
|
||||
|
||||
TODO
|
||||
----
|
||||
|
||||
* Special consideration should be given to `deadlocks`_.
|
||||
|
||||
|
||||
.. _deadlocks:
|
||||
https://mina.naguib.ca/blog/2010/11/22/postgresql-foreign-key-deadlocks.html
|
||||
|
||||
"""
|
||||
|
||||
|
||||
from collections import defaultdict
|
||||
from weakref import WeakKeyDictionary
|
||||
|
||||
import sqlalchemy as sa
|
||||
import sqlalchemy.event
|
||||
import sqlalchemy.orm
|
||||
from sqlalchemy.ext.declarative import declared_attr
|
||||
from sqlalchemy.sql.functions import _FunctionGenerator
|
||||
|
||||
from .compat import _select_args, get_scalar_subquery
|
||||
from .functions.orm import get_column_key
|
||||
from .relationships import (
|
||||
chained_join,
|
||||
path_to_relationships,
|
||||
select_correlated_expression
|
||||
)
|
||||
|
||||
aggregated_attrs = WeakKeyDictionary()
|
||||
|
||||
|
||||
class AggregatedAttribute(declared_attr):
    """
    declared_attr subclass produced by the :func:`aggregated` decorator.

    It records the decorated function together with its relationship path
    and aggregate column, while still exposing the plain Column object to
    SQLAlchemy's declarative machinery.
    """
    def __init__(
        self,
        fget,
        relationship,
        column,
        *args,
        **kwargs
    ):
        super().__init__(fget, *args, **kwargs)
        self.__doc__ = fget.__doc__
        # The Column that will hold the aggregated value.
        self.column = column
        # Relationship path (possibly dotted) the aggregate is computed over.
        self.relationship = relationship

    # NOTE: parameter names are deliberately swapped relative to the usual
    # __get__(self, obj, cls) convention: 'desc' is the descriptor itself.
    def __get__(desc, self, cls):
        # Register (fget, relationship, column) for this class so that
        # AggregationManager.update_generator_registry can pick it up once
        # mappers are configured; then hand the bare column to declarative.
        value = (desc.fget, desc.relationship, desc.column)
        if cls not in aggregated_attrs:
            aggregated_attrs[cls] = [value]
        else:
            aggregated_attrs[cls].append(value)
        return desc.column
|
||||
|
||||
|
||||
def local_condition(prop, objects):
    """
    Build an IN-condition restricting an UPDATE to the parent rows that own
    the given objects, derived from the relationship property's
    local/remote column pairs.

    Returns None when no usable values could be read from the objects.
    """
    pairs = prop.local_remote_pairs
    if prop.secondary is not None:
        # Many-to-many relationship: take columns from the second pair.
        # NOTE(review): parent_column and fetched_column are read from the
        # exact same pair element here -- looks suspicious; confirm this is
        # intentional before changing it.
        parent_column = pairs[1][0]
        fetched_column = pairs[1][0]
    else:
        parent_column = pairs[0][0]
        fetched_column = pairs[0][1]

    # Map the fetched column back to its attribute key on the related mapper.
    key = get_column_key(prop.mapper, fetched_column)

    values = []
    for obj in objects:
        try:
            values.append(getattr(obj, key))
        except sa.orm.exc.ObjectDeletedError:
            # Object was deleted during this flush; simply skip it.
            pass

    if values:
        return parent_column.in_(values)
|
||||
|
||||
|
||||
def aggregate_expression(expr, class_):
    """
    Normalize an aggregate expression into a SQL construct.

    Visitable SQL elements pass through untouched, bare function
    generators (e.g. ``sa.func.count``) are invoked with the literal
    ``'1'``, and anything else is treated as a callable taking the class.
    """
    if isinstance(expr, sa.sql.visitors.Visitable):
        return expr
    if isinstance(expr, _FunctionGenerator):
        return expr(sa.sql.text('1'))
    return expr(class_)
|
||||
|
||||
|
||||
class AggregatedValue:
    """
    Bundles one aggregated column together with the machinery needed to
    build the UPDATE statement that refreshes it.
    """
    def __init__(self, class_, attr, path, expr):
        # class_: declarative class owning the aggregate column.
        self.class_ = class_
        # attr: the aggregate Column that gets updated.
        self.attr = attr
        # path: relationship path string (possibly dotted).
        self.path = path
        # Relationships along the path, stored deepest-first.
        self.relationships = list(
            reversed(path_to_relationships(path, class_))
        )
        self.expr = aggregate_expression(expr, class_)

    @property
    def aggregate_query(self):
        """Scalar subquery computing the aggregate for each parent row."""
        query = select_correlated_expression(
            self.class_,
            self.expr,
            self.path,
            self.relationships[0].mapper.class_
        )

        return get_scalar_subquery(query)

    def update_query(self, objects):
        """
        Build the UPDATE statement refreshing this aggregate for the parent
        rows related to ``objects``. Returns None when no restricting
        condition could be derived from the objects.
        """
        table = self.class_.__table__
        query = table.update().values(
            {self.attr: self.aggregate_query}
        )
        if len(self.relationships) == 1:
            # Direct relationship: restrict the UPDATE via the relationship's
            # local/remote column pair.
            prop = self.relationships[-1].property
            condition = local_condition(prop, objects)
            if condition is not None:
                return query.where(condition)
        else:
            # Builds query such as:
            #
            # UPDATE catalog SET product_count = (aggregate_query)
            # WHERE id IN (
            #     SELECT catalog_id
            #     FROM category
            #     INNER JOIN sub_category
            #         ON category.id = sub_category.category_id
            #     WHERE sub_category.id IN (product_sub_category_ids)
            # )
            property_ = self.relationships[-1].property
            remote_pairs = property_.local_remote_pairs
            local = remote_pairs[0][0]
            remote = remote_pairs[0][1]
            condition = local_condition(
                self.relationships[0].property,
                objects
            )
            if condition is not None:
                return query.where(
                    local.in_(
                        sa.select(
                            *_select_args(remote)
                        ).select_from(
                            chained_join(*reversed(self.relationships))
                        ).where(
                            condition
                        )
                    )
                )
|
||||
|
||||
|
||||
class AggregationManager:
    """
    Collects AggregatedValue generators per class and executes the
    corresponding UPDATE statements after every session flush.
    """
    def __init__(self):
        self.reset()

    def reset(self):
        # class -> list of AggregatedValue objects whose aggregates must be
        # refreshed when instances of that class change.
        self.generator_registry = defaultdict(list)

    def register_listeners(self):
        # Populate the registry once all mappers have been configured...
        sa.event.listen(
            sa.orm.Mapper,
            'after_configured',
            self.update_generator_registry
        )
        # ...and rebuild the aggregates after each session flush.
        sa.event.listen(
            sa.orm.session.Session,
            'after_flush',
            self.construct_aggregate_queries
        )

    def update_generator_registry(self):
        """
        Build AggregatedValue objects for everything recorded in
        ``aggregated_attrs``, keyed by the class at the deep end of each
        relationship path.
        """
        for class_, attrs in aggregated_attrs.items():
            for expr, path, column in attrs:
                value = AggregatedValue(
                    class_=class_,
                    attr=column,
                    path=path,
                    # expr is the decorated function; calling it with the
                    # class yields the aggregate expression.
                    expr=expr(class_)
                )
                key = value.relationships[0].mapper.class_
                self.generator_registry[key].append(
                    value
                )

    def construct_aggregate_queries(self, session, ctx):
        """
        'after_flush' hook: group flushed objects by registered class and
        execute one UPDATE per aggregate value that yields a query.
        """
        object_dict = defaultdict(list)
        for obj in session:
            for class_ in self.generator_registry:
                if isinstance(obj, class_):
                    object_dict[class_].append(obj)

        for class_, objects in object_dict.items():
            for aggregate_value in self.generator_registry[class_]:
                query = aggregate_value.update_query(objects)
                if query is not None:
                    session.execute(query)
|
||||
|
||||
|
||||
# Module-level singleton: wires the aggregate bookkeeping into SQLAlchemy's
# mapper-configuration and session-flush events as soon as this module is
# imported.
manager = AggregationManager()
manager.register_listeners()
|
||||
|
||||
|
||||
def aggregated(relationship, column):
    """
    Decorator that turns the decorated function into an aggregated
    attribute. The decorated function should return an aggregate select
    expression.

    :param relationship:
        Defines the relationship of which the aggregate is calculated from.
        The class needs to have given relationship in order to calculate the
        aggregate.
    :param column:
        SQLAlchemy Column object. The column definition of this aggregate
        attribute.
    """
    def decorator(func):
        return AggregatedAttribute(func, relationship, column)
    return decorator
|
|
@ -0,0 +1,182 @@
|
|||
"""
|
||||
The functions in this module can be used for testing the constraints of
|
||||
your models. Each assert function runs SQL UPDATEs that check for the existence
|
||||
of given constraint. Consider the following model::
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.String(200), nullable=True)
|
||||
email = sa.Column(sa.String(255), nullable=False)
|
||||
|
||||
|
||||
user = User(name='John Doe', email='john@example.com')
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
||||
|
||||
We can easily test the constraints by assert_* functions::
|
||||
|
||||
|
||||
from sqlalchemy_utils import (
|
||||
assert_nullable,
|
||||
assert_non_nullable,
|
||||
assert_max_length
|
||||
)
|
||||
|
||||
assert_nullable(user, 'name')
|
||||
assert_non_nullable(user, 'email')
|
||||
assert_max_length(user, 'name', 200)
|
||||
|
||||
# raises AssertionError because the max length of email is 255
|
||||
assert_max_length(user, 'email', 300)
|
||||
"""
|
||||
from decimal import Decimal
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import ARRAY
|
||||
from sqlalchemy.exc import DataError, IntegrityError
|
||||
|
||||
|
||||
def _update_field(obj, field, value):
    """Issue a bare SQL UPDATE setting ``field`` to ``value`` on obj's table."""
    session = sa.orm.object_session(obj)
    target = sa.inspect(obj.__class__).columns[field]
    statement = target.table.update().values(**{target.key: value})
    session.execute(statement)
    session.flush()
|
||||
|
||||
|
||||
def _expect_successful_update(obj, field, value, reraise_exc):
    """
    Update ``field`` to ``value`` and fail if ``reraise_exc`` is raised.

    :param obj: SQLAlchemy declarative model object
    :param field: name of the column to update
    :param value: value to assign to the column
    :param reraise_exc: exception type (or tuple of types) that must NOT
        occur during the update
    :raises AssertionError: if the update raised ``reraise_exc``
    """
    try:
        _update_field(obj, field, value)
    except reraise_exc as e:
        session = sa.orm.object_session(obj)
        session.rollback()
        # Raise explicitly rather than `assert False, str(e)`: assert
        # statements are stripped under `python -O`, which would make this
        # check silently pass.
        raise AssertionError(str(e)) from e
|
||||
|
||||
|
||||
def _expect_failing_update(obj, field, value, expected_exc):
    """
    Update ``field`` to ``value`` and require that ``expected_exc`` is
    raised; the session is always rolled back afterwards so later checks
    start from a clean state.
    """
    try:
        _update_field(obj, field, value)
    except expected_exc:
        pass
    else:
        raise AssertionError('Expected update to raise %s' % expected_exc)
    finally:
        sa.orm.object_session(obj).rollback()
|
||||
|
||||
|
||||
def _repeated_value(type_):
    """
    Return a one-element value that can be multiplied up to a target length
    for the given column type: a single-item list for ARRAY columns, the
    one-character string ``'a'`` for everything else.

    :raises TypeError: for ARRAY columns with an unrecognized item type
    """
    if not isinstance(type_, ARRAY):
        return 'a'
    item_type = type_.item_type
    if isinstance(item_type, sa.Integer):
        return [0]
    if isinstance(item_type, sa.String):
        return ['a']
    if isinstance(item_type, sa.Numeric):
        return [Decimal('0')]
    raise TypeError('Unknown array item type')
|
||||
|
||||
|
||||
def _expected_exception(type_):
    """Return the exception class an over-long update should raise."""
    # ARRAY length violations come from check constraints (IntegrityError);
    # over-long strings trip the column type itself (DataError).
    return IntegrityError if isinstance(type_, ARRAY) else DataError
|
||||
|
||||
|
||||
def assert_nullable(obj, column):
    """
    Assert that the given column accepts NULL. Checked by issuing an SQL
    update that sets the column to None and expecting it to succeed.

    :param obj: SQLAlchemy declarative model object
    :param column: Name of the column
    """
    _expect_successful_update(obj, column, None, IntegrityError)
|
||||
|
||||
|
||||
def assert_non_nullable(obj, column):
    """
    Assert that the given column rejects NULL. Checked by issuing an SQL
    update that sets the column to None and expecting it to fail.

    :param obj: SQLAlchemy declarative model object
    :param column: Name of the column
    """
    _expect_failing_update(obj, column, None, IntegrityError)
|
||||
|
||||
|
||||
def assert_max_length(obj, column, max_length):
    """
    Assert that the given column accepts values of length ``max_length``
    but rejects values one element longer. Supports string typed columns
    as well as PostgreSQL array typed columns.

    In the following example we add a check constraint that a user can
    have a maximum of 5 favorite colors and then test this.::


        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            favorite_colors = sa.Column(ARRAY(sa.String), nullable=False)
            __table_args__ = (
                sa.CheckConstraint(
                    sa.func.array_length(favorite_colors, 1) <= 5
                )
            )


        user = User(name='John Doe', favorite_colors=['red', 'blue'])
        session.add(user)
        session.commit()


        assert_max_length(user, 'favorite_colors', 5)


    :param obj: SQLAlchemy declarative model object
    :param column: Name of the column
    :param max_length: Maximum length of given column
    """
    type_ = sa.inspect(obj.__class__).columns[column].type
    unit = _repeated_value(type_)
    failure_exc = _expected_exception(type_)
    # Exactly max_length elements must be accepted...
    _expect_successful_update(obj, column, unit * max_length, failure_exc)
    # ...while one more must be rejected.
    _expect_failing_update(obj, column, unit * (max_length + 1), failure_exc)
|
||||
|
||||
|
||||
def assert_min_value(obj, column, min_value):
    """
    Assert that the given column must have a minimum value of ``min_value``:
    updating to ``min_value`` succeeds, updating to ``min_value - 1`` fails.

    :param obj: SQLAlchemy declarative model object
    :param column: Name of the column
    :param min_value: The minimum allowed value for given column
    """
    _expect_successful_update(obj, column, min_value, IntegrityError)
    _expect_failing_update(obj, column, min_value - 1, IntegrityError)
|
||||
|
||||
|
||||
def assert_max_value(obj, column, min_value):
    """
    Assert that the given column must have a maximum value of ``min_value``:
    updating to ``min_value`` succeeds, updating to ``min_value + 1`` fails.

    :param obj: SQLAlchemy declarative model object
    :param column: Name of the column
    :param min_value:
        The maximum allowed value for the given column. (The parameter is
        misleadingly named ``min_value``; it is kept for backward
        compatibility with keyword-argument callers.)
    """
    _expect_successful_update(obj, column, min_value, IntegrityError)
    _expect_failing_update(obj, column, min_value + 1, IntegrityError)
|
|
@ -0,0 +1,86 @@
|
|||
import re
|
||||
import sys
|
||||
|
||||
if sys.version_info >= (3, 8):
|
||||
from importlib.metadata import metadata
|
||||
else:
|
||||
from importlib_metadata import metadata
|
||||
|
||||
|
||||
def get_sqlalchemy_version(version=metadata("sqlalchemy")["Version"]):
    """
    Extract the sqlalchemy version as a tuple of integers.

    The default argument is the installed sqlalchemy's version string,
    resolved once at import time. Returns an empty tuple when the string
    does not start with digits.
    """
    match = re.search(r"^(\d+)(?:\.(\d+)(?:\.(\d+))?)?", version)
    if match is None:
        return ()
    return tuple(int(part) for part in match.groups() if part is not None)
|
||||
|
||||
|
||||
_sqlalchemy_version = get_sqlalchemy_version()
|
||||
|
||||
|
||||
# In sqlalchemy 2.0, some functions moved to sqlalchemy.orm.
|
||||
# In sqlalchemy 1.3, they are only available in .ext.declarative.
|
||||
# In sqlalchemy 1.4, they are available in both places.
|
||||
#
|
||||
# WARNING
|
||||
# -------
|
||||
#
|
||||
# These imports are for internal, private compatibility.
|
||||
# They are not supported and may change or move at any time.
|
||||
# Do not import these in your own code.
|
||||
#
|
||||
|
||||
if _sqlalchemy_version >= (1, 4):
|
||||
from sqlalchemy.orm import declarative_base as _declarative_base
|
||||
from sqlalchemy.orm import synonym_for as _synonym_for
|
||||
else:
|
||||
from sqlalchemy.ext.declarative import \
|
||||
declarative_base as _declarative_base
|
||||
from sqlalchemy.ext.declarative import synonym_for as _synonym_for
|
||||
|
||||
|
||||
# scalar subqueries
|
||||
if _sqlalchemy_version >= (1, 4):
    def get_scalar_subquery(query):
        # 1.4+ spells this query.scalar_subquery().
        return query.scalar_subquery()
else:
    def get_scalar_subquery(query):
        # 1.3 spells the same operation query.as_scalar().
        return query.as_scalar()
|
||||
|
||||
|
||||
# In sqlalchemy 2.0, select() columns are positional.
|
||||
# In sqlalchemy 1.3, select() columns must be wrapped in a list.
|
||||
#
|
||||
# _select_args() is designed so its return value can be unpacked:
|
||||
#
|
||||
# select(*_select_args(1, 2))
|
||||
#
|
||||
# When sqlalchemy 1.3 support is dropped, remove the call to _select_args()
|
||||
# and keep the arguments the same:
|
||||
#
|
||||
# select(1, 2)
|
||||
#
|
||||
# WARNING
|
||||
# -------
|
||||
#
|
||||
# _select_args() is a private, internal function.
|
||||
# It is not supported and may change or move at any time.
|
||||
# Do not import this in your own code.
|
||||
#
|
||||
if _sqlalchemy_version >= (1, 4):
    def _select_args(*args):
        # 1.4+/2.0: select() takes its columns positionally, so pass
        # the arguments straight through for unpacking.
        return args
else:
    def _select_args(*args):
        # 1.3: select() expects the columns wrapped in a single list.
        return [args]
|
||||
|
||||
|
||||
__all__ = (
|
||||
"_declarative_base",
|
||||
"get_scalar_subquery",
|
||||
"get_sqlalchemy_version",
|
||||
"_select_args",
|
||||
"_synonym_for",
|
||||
)
|
|
@ -0,0 +1,10 @@
|
|||
"""
|
||||
Global SQLAlchemy-Utils exception classes.
|
||||
"""
|
||||
|
||||
|
||||
class ImproperlyConfigured(Exception):
    """
    Raised when SQLAlchemy-Utils is improperly configured -- normally
    because a utility is used that depends on a missing library.
    """
|
|
@ -0,0 +1,60 @@
|
|||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.sql.expression import ColumnElement, FunctionElement
|
||||
from sqlalchemy.sql.functions import GenericFunction
|
||||
|
||||
from .functions.orm import quote
|
||||
|
||||
|
||||
class array_get(FunctionElement):
    # SQL function element: array_get(arr, i) compiles to (arr)[i + 1].
    name = 'array_get'


@compiles(array_get)
def compile_array_get(element, compiler, **kw):
    """Compile ``array_get(expr, index)`` into ``(expr)[index + 1]``."""
    args = list(element.clauses)
    if len(args) != 2:
        raise Exception(
            "Function 'array_get' expects two arguments (%d given)." %
            len(args)
        )

    # The index must be a bound literal carrying an int value; arbitrary
    # SQL expressions are rejected.
    if not hasattr(args[1], 'value') or not isinstance(args[1].value, int):
        raise Exception(
            "Second argument should be an integer."
        )
    # The +1 shifts the given zero-based index by one (SQL array
    # subscripts are one-based here).
    return '({})[{}]'.format(
        compiler.process(args[0]),
        sa.text(str(args[1].value + 1))
    )
|
||||
|
||||
|
||||
class row_to_json(GenericFunction):
    # Exposes the database's row_to_json() function with a JSON return type.
    name = 'row_to_json'
    type = postgresql.JSON


@compiles(row_to_json, 'postgresql')
def compile_row_to_json(element, compiler, **kw):
    # Render as row_to_json(<clauses>) on the postgresql dialect.
    return f"{element.name}({compiler.process(element.clauses)})"
|
||||
|
||||
|
||||
class json_array_length(GenericFunction):
    # Exposes json_array_length() with an Integer return type.
    name = 'json_array_length'
    type = sa.Integer


@compiles(json_array_length, 'postgresql')
def compile_json_array_length(element, compiler, **kw):
    # Render as json_array_length(<clauses>) on the postgresql dialect.
    return f"{element.name}({compiler.process(element.clauses)})"
|
||||
|
||||
|
||||
class Asterisk(ColumnElement):
    """Column element rendering ``<selectable>.*`` for a given selectable."""
    def __init__(self, selectable):
        self.selectable = selectable


@compiles(Asterisk)
def compile_asterisk(element, compiler, **kw):
    # Quote the selectable's name with the dialect's rules, then append .*
    return '%s.*' % quote(compiler.dialect, element.selectable.name)
|
|
@ -0,0 +1,42 @@
|
|||
from .database import ( # noqa
|
||||
create_database,
|
||||
database_exists,
|
||||
drop_database,
|
||||
escape_like,
|
||||
has_index,
|
||||
has_unique_index,
|
||||
is_auto_assigned_date_column,
|
||||
json_sql,
|
||||
jsonb_sql
|
||||
)
|
||||
from .foreign_keys import ( # noqa
|
||||
dependent_objects,
|
||||
get_fk_constraint_for_columns,
|
||||
get_referencing_foreign_keys,
|
||||
group_foreign_keys,
|
||||
merge_references,
|
||||
non_indexed_foreign_keys
|
||||
)
|
||||
from .mock import create_mock_engine, mock_engine # noqa
|
||||
from .orm import ( # noqa
|
||||
cast_if,
|
||||
get_bind,
|
||||
get_class_by_table,
|
||||
get_column_key,
|
||||
get_columns,
|
||||
get_declarative_base,
|
||||
get_hybrid_properties,
|
||||
get_mapper,
|
||||
get_primary_keys,
|
||||
get_tables,
|
||||
get_type,
|
||||
getdotattr,
|
||||
has_changes,
|
||||
identity,
|
||||
is_loaded,
|
||||
naturally_equivalent,
|
||||
quote,
|
||||
table_name
|
||||
)
|
||||
from .render import render_expression, render_statement # noqa
|
||||
from .sort_query import make_order_by_deterministic # noqa
|
|
@ -0,0 +1,659 @@
|
|||
import itertools
|
||||
import os
|
||||
from collections.abc import Mapping, Sequence
|
||||
from copy import copy
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.engine.url import make_url
|
||||
from sqlalchemy.exc import OperationalError, ProgrammingError
|
||||
|
||||
from ..utils import starts_with
|
||||
from .orm import quote
|
||||
|
||||
|
||||
def escape_like(string, escape_char='*'):
    """
    Escape the string parameter used in SQL LIKE expressions.

    ::

        from sqlalchemy_utils import escape_like


        query = session.query(User).filter(
            User.name.ilike(escape_like('John'))
        )


    :param string: a string to escape
    :param escape_char: escape character
    """
    # Double up existing escape characters first, then guard the LIKE
    # wildcards '%' and '_' with the escape character.
    escaped = string.replace(escape_char, escape_char * 2)
    escaped = escaped.replace('%', escape_char + '%')
    return escaped.replace('_', escape_char + '_')
|
||||
|
||||
|
||||
def json_sql(value, scalars_to_json=True):
    """
    Convert python data structures to PostgreSQL specific SQLAlchemy JSON
    constructs. This function is extremely useful if you need to build
    PostgreSQL JSON on python side.

    .. note::

        This function needs PostgreSQL >= 9.4

    Scalars are converted to to_json SQLAlchemy function objects

    ::

        json_sql(1)     # Equals SQL: to_json(1)

        json_sql('a')   # to_json('a')


    Mappings are converted to json_build_object constructs

    ::

        json_sql({'a': 'c', '2': 5})  # json_build_object('a', 'c', '2', 5)


    Sequences (other than strings) are converted to json_build_array constructs

    ::

        json_sql([1, 2, 3])  # json_build_array(1, 2, 3)


    You can also nest these data structures

    ::

        json_sql({'a': [1, 2, 3]})
        # json_build_object('a', json_build_array[1, 2, 3])


    :param value:
        value to be converted to SQLAlchemy PostgreSQL function constructs
    :param scalars_to_json:
        whether scalar values should be wrapped in ``to_json``; recursive
        calls pass False because values nested inside json_build_object /
        json_build_array must not be wrapped again
    """
    scalar_convert = sa.text
    if scalars_to_json:
        # Shadow the plain sa.text converter with one that wraps scalars
        # in PostgreSQL's to_json().
        def scalar_convert(a):
            return sa.func.to_json(sa.text(a))

    if isinstance(value, Mapping):
        # itertools.chain(*value.items()) flattens {'a': 1} into ('a', 1),
        # matching json_build_object's alternating key/value argument form.
        return sa.func.json_build_object(
            *(
                json_sql(v, scalars_to_json=False)
                for v in itertools.chain(*value.items())
            )
        )
    elif isinstance(value, str):
        # Checked before Sequence: str is a Sequence but must be treated
        # as a scalar (and quoted as an SQL string literal).
        return scalar_convert(f"'{value}'")
    elif isinstance(value, Sequence):
        return sa.func.json_build_array(
            *(
                json_sql(v, scalars_to_json=False)
                for v in value
            )
        )
    elif isinstance(value, (int, float)):
        return scalar_convert(str(value))
    # Anything else (e.g. SQLAlchemy expressions) is passed through as-is.
    return value
|
||||
|
||||
|
||||
def jsonb_sql(value, scalars_to_jsonb=True):
    """
    Convert python data structures to PostgreSQL specific SQLAlchemy JSONB
    constructs. This function is extremely useful if you need to build
    PostgreSQL JSONB on python side.

    .. note::

        This function needs PostgreSQL >= 9.4

    Scalars are converted to to_jsonb SQLAlchemy function objects

    ::

        jsonb_sql(1)     # Equals SQL: to_jsonb(1)

        jsonb_sql('a')   # to_jsonb('a')


    Mappings are converted to jsonb_build_object constructs

    ::

        jsonb_sql({'a': 'c', '2': 5})  # jsonb_build_object('a', 'c', '2', 5)


    Sequences (other than strings) converted to jsonb_build_array constructs

    ::

        jsonb_sql([1, 2, 3])  # jsonb_build_array(1, 2, 3)


    You can also nest these data structures

    ::

        jsonb_sql({'a': [1, 2, 3]})
        # jsonb_build_object('a', jsonb_build_array[1, 2, 3])


    :param value:
        value to be converted to SQLAlchemy PostgreSQL function constructs
    :param scalars_to_jsonb:
        whether scalar values should be wrapped in ``to_jsonb``; recursive
        calls pass False because values nested inside jsonb_build_object /
        jsonb_build_array must not be wrapped again
    """
    scalar_convert = sa.text
    if scalars_to_jsonb:
        # Shadow the plain sa.text converter with one that wraps scalars
        # in PostgreSQL's to_jsonb().
        def scalar_convert(a):
            return sa.func.to_jsonb(sa.text(a))

    if isinstance(value, Mapping):
        # itertools.chain(*value.items()) flattens {'a': 1} into ('a', 1),
        # matching jsonb_build_object's alternating key/value argument form.
        return sa.func.jsonb_build_object(
            *(
                jsonb_sql(v, scalars_to_jsonb=False)
                for v in itertools.chain(*value.items())
            )
        )
    elif isinstance(value, str):
        # Checked before Sequence: str is a Sequence but must be treated
        # as a scalar (and quoted as an SQL string literal).
        return scalar_convert(f"'{value}'")
    elif isinstance(value, Sequence):
        return sa.func.jsonb_build_array(
            *(
                jsonb_sql(v, scalars_to_jsonb=False)
                for v in value
            )
        )
    elif isinstance(value, (int, float)):
        return scalar_convert(str(value))
    # Anything else (e.g. SQLAlchemy expressions) is passed through as-is.
    return value
|
||||
|
||||
|
||||
def has_index(column_or_constraint):
    """
    Return whether or not given column or the columns of given foreign key
    constraint have an index. A column has an index if it has a single column
    index or it is the first column in compound column index.

    A foreign key constraint has an index if the constraint columns are the
    first columns in compound column index.

    Primary key indexes are supported as well: a column that leads the
    table's primary key counts as indexed.

    :param column_or_constraint:
        SQLAlchemy Column object or SA ForeignKeyConstraint object

    :raises TypeError: if the column does not belong to a Table object

    .. versionadded: 0.26.2

    .. versionchanged: 0.30.18
        Added support for foreign key constraints.

    ::

        from sqlalchemy_utils import has_index

        table = Article.__table__
        has_index(table.c.is_published)        # True if indexed
        constraint = list(table.foreign_keys)[0].constraint
        has_index(constraint)                  # True if FK columns lead an index
    """
    table = column_or_constraint.table
    if not isinstance(table, sa.Table):
        raise TypeError(
            'Only columns belonging to Table objects are supported. Given '
            'column belongs to %r.' % table
        )

    if isinstance(column_or_constraint, sa.ForeignKeyConstraint):
        target_columns = list(column_or_constraint.columns.values())
    else:
        target_columns = [column_or_constraint]

    # The primary key acts as an implicit index: match if our columns are
    # a prefix of the primary key column list.
    primary_keys = table.primary_key.columns.values()
    if primary_keys and starts_with(primary_keys, target_columns):
        return True

    # Otherwise look for an explicit index whose leading columns match.
    return any(
        starts_with(index.columns.values(), target_columns)
        for index in table.indexes
    )
|
||||
|
||||
|
||||
def has_unique_index(column_or_constraint):
    """
    Return whether or not given column or given foreign key constraint has a
    unique index.

    A column has a unique index if it has a single column primary key index
    or it has a single column UniqueConstraint.

    A foreign key constraint has a unique index if the columns of the
    constraint are the same as the columns of table primary key or the
    columns of any unique index or any unique constraint of the given table.

    :param column_or_constraint:
        SQLAlchemy Column object or SA ForeignKeyConstraint object

    :raises TypeError: if given column does not belong to a Table object

    .. versionadded: 0.27.1

    .. versionchanged: 0.30.18
        Added support for foreign key constraints.

        Fixed support for unique indexes (previously only worked for unique
        constraints)

    ::

        from sqlalchemy_utils import has_unique_index

        table = Article.__table__
        has_unique_index(table.c.id)           # True
        constraint = list(table.foreign_keys)[0].constraint
        has_unique_index(constraint)           # True if FK cols are unique
    """
    table = column_or_constraint.table
    if not isinstance(table, sa.Table):
        raise TypeError(
            'Only columns belonging to Table objects are supported. Given '
            'column belongs to %r.' % table
        )

    if isinstance(column_or_constraint, sa.ForeignKeyConstraint):
        target_columns = list(column_or_constraint.columns.values())
    else:
        target_columns = [column_or_constraint]

    # Exact match (not prefix) is required for uniqueness in every case.
    if target_columns == list(table.primary_key.columns.values()):
        return True
    if any(
        target_columns == list(constraint.columns.values())
        for constraint in table.constraints
        if isinstance(constraint, sa.sql.schema.UniqueConstraint)
    ):
        return True
    return any(
        target_columns == list(index.columns.values())
        for index in table.indexes
        if index.unique
    )
|
||||
|
||||
|
||||
def is_auto_assigned_date_column(column):
    """
    Returns whether or not given SQLAlchemy Column object's is auto assigned
    DateTime or Date.

    :param column: SQLAlchemy Column object
    """
    # Note: sa.DateTime is checked alongside sa.Date; the second operand of
    # ``and`` (a truthy default object or a falsy value) is returned as-is.
    is_temporal = isinstance(column.type, (sa.DateTime, sa.Date))
    return is_temporal and (
        column.default or
        column.server_default or
        column.onupdate or
        column.server_onupdate
    )
|
||||
|
||||
|
||||
def _set_url_database(url: sa.engine.url.URL, database):
    """Set the database of an engine URL.

    :param url: A SQLAlchemy engine URL.
    :param database: New database to set.

    :return: a URL object with the database replaced (a new immutable URL
        on SQLAlchemy >= 1.4, a mutated copy on older versions)
    """
    if hasattr(url, '_replace'):
        # Cannot use URL.set() as database may need to be set to None.
        ret = url._replace(database=database)
    else:  # SQLAlchemy <1.4
        # Pre-1.4 URLs are mutable; copy first so the caller's URL is
        # left untouched.
        url = copy(url)
        url.database = database
        ret = url
    assert ret.database == database, ret
    return ret
|
||||
|
||||
|
||||
def _get_scalar_result(engine, sql):
    """Execute *sql* on a short-lived connection and return its scalar
    result; the context manager returns the connection to the pool even
    when execution raises."""
    with engine.connect() as conn:
        return conn.scalar(sql)
|
||||
|
||||
|
||||
def _sqlite_file_exists(database):
|
||||
if not os.path.isfile(database) or os.path.getsize(database) < 100:
|
||||
return False
|
||||
|
||||
with open(database, 'rb') as f:
|
||||
header = f.read(100)
|
||||
|
||||
return header[:16] == b'SQLite format 3\x00'
|
||||
|
||||
|
||||
def database_exists(url):
    """Check if a database exists.

    :param url: A SQLAlchemy engine URL.

    Performs backend-specific testing to quickly determine if a database
    exists on the server. ::

        database_exists('postgresql://postgres@localhost/name') #=> False
        create_database('postgresql://postgres@localhost/name')
        database_exists('postgresql://postgres@localhost/name') #=> True

    Supports checking against a constructed URL as well. ::

        engine = create_engine('postgresql://postgres@localhost/name')
        database_exists(engine.url) #=> False
        create_database(engine.url)
        database_exists(engine.url) #=> True

    """

    url = make_url(url)
    database = url.database
    dialect_name = url.get_dialect().name
    engine = None
    try:
        if dialect_name == 'postgresql':
            text = "SELECT 1 FROM pg_database WHERE datname='%s'" % database
            # Try the target database itself first, then the usual
            # maintenance databases; any database we can connect to is
            # enough to query pg_database.
            for db in (database, 'postgres', 'template1', 'template0', None):
                url = _set_url_database(url, database=db)
                engine = sa.create_engine(url)
                try:
                    return bool(_get_scalar_result(engine, sa.text(text)))
                except (ProgrammingError, OperationalError):
                    # Could not connect to this candidate; try the next.
                    pass
            return False

        elif dialect_name == 'mysql':
            # Connect without selecting a database and consult the
            # information schema.
            url = _set_url_database(url, database=None)
            engine = sa.create_engine(url)
            text = ("SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA "
                    "WHERE SCHEMA_NAME = '%s'" % database)
            return bool(_get_scalar_result(engine, sa.text(text)))

        elif dialect_name == 'sqlite':
            url = _set_url_database(url, database=None)
            engine = sa.create_engine(url)
            if database:
                return database == ':memory:' or _sqlite_file_exists(database)
            else:
                # The default SQLAlchemy database is in memory, and :memory: is
                # not required, thus we should support that use case.
                return True
        else:
            # Generic fallback: a trivial query succeeds only if we can
            # actually connect to the database.
            text = 'SELECT 1'
            try:
                engine = sa.create_engine(url)
                return bool(_get_scalar_result(engine, sa.text(text)))
            except (ProgrammingError, OperationalError):
                return False
    finally:
        # Dispose whatever engine was last created so no pooled
        # connections are leaked.
        if engine:
            engine.dispose()
|
||||
|
||||
|
||||
def create_database(url, encoding='utf8', template=None):
    """Issue the appropriate CREATE DATABASE statement.

    :param url: A SQLAlchemy engine URL.
    :param encoding: The encoding to create the database as.
    :param template:
        The name of the template from which to create the new database. At the
        moment only supported by PostgreSQL driver.

    To create a database, you can pass a simple URL that would have
    been passed to ``create_engine``. ::

        create_database('postgresql://postgres@localhost/name')

    You may also pass the url from an existing engine. ::

        create_database(engine.url)

    Has full support for mysql, postgres, and sqlite. In theory,
    other database engines should be supported.
    """

    url = make_url(url)
    database = url.database
    dialect_name = url.get_dialect().name
    dialect_driver = url.get_dialect().driver

    # Point the URL at the dialect's maintenance database — we must connect
    # somewhere other than the database being created.
    if dialect_name == 'postgresql':
        url = _set_url_database(url, database="postgres")
    elif dialect_name == 'mssql':
        url = _set_url_database(url, database="master")
    elif dialect_name == 'cockroachdb':
        url = _set_url_database(url, database="defaultdb")
    elif not dialect_name == 'sqlite':
        url = _set_url_database(url, database=None)

    # CREATE DATABASE cannot run inside a transaction on these drivers,
    # so force autocommit isolation.
    if (dialect_name == 'mssql' and dialect_driver in {'pymssql', 'pyodbc'}) \
            or (dialect_name == 'postgresql' and dialect_driver in {
                'asyncpg', 'pg8000', 'psycopg', 'psycopg2', 'psycopg2cffi'}):
        engine = sa.create_engine(url, isolation_level='AUTOCOMMIT')
    else:
        engine = sa.create_engine(url)

    if dialect_name == 'postgresql':
        if not template:
            template = 'template1'

        with engine.begin() as conn:
            text = "CREATE DATABASE {} ENCODING '{}' TEMPLATE {}".format(
                quote(conn, database),
                encoding,
                quote(conn, template)
            )
            conn.execute(sa.text(text))

    elif dialect_name == 'mysql':
        with engine.begin() as conn:
            text = "CREATE DATABASE {} CHARACTER SET = '{}'".format(
                quote(conn, database),
                encoding
            )
            conn.execute(sa.text(text))

    elif dialect_name == 'sqlite' and database != ':memory:':
        if database:
            # SQLite creates the file lazily; issue a throwaway DDL pair so
            # the database file actually appears on disk.
            with engine.begin() as conn:
                conn.execute(sa.text('CREATE TABLE DB(id int)'))
                conn.execute(sa.text('DROP TABLE DB'))

    else:
        with engine.begin() as conn:
            text = f'CREATE DATABASE {quote(conn, database)}'
            conn.execute(sa.text(text))

    engine.dispose()
|
||||
|
||||
|
||||
def drop_database(url):
    """Issue the appropriate DROP DATABASE statement.

    :param url: A SQLAlchemy engine URL.

    Works similar to the :ref:`create_database` method in that both url text
    and a constructed url are accepted. ::

        drop_database('postgresql://postgres@localhost/name')
        drop_database(engine.url)

    """

    url = make_url(url)
    database = url.database
    dialect_name = url.get_dialect().name
    dialect_driver = url.get_dialect().driver

    # Point the URL at the dialect's maintenance database — we cannot be
    # connected to the database we are about to drop.
    if dialect_name == 'postgresql':
        url = _set_url_database(url, database="postgres")
    elif dialect_name == 'mssql':
        url = _set_url_database(url, database="master")
    elif dialect_name == 'cockroachdb':
        url = _set_url_database(url, database="defaultdb")
    elif not dialect_name == 'sqlite':
        url = _set_url_database(url, database=None)

    # DROP DATABASE cannot run inside a transaction on these drivers.
    if dialect_name == 'mssql' and dialect_driver in {'pymssql', 'pyodbc'}:
        engine = sa.create_engine(url, connect_args={'autocommit': True})
    elif dialect_name == 'postgresql' and dialect_driver in {
            'asyncpg', 'pg8000', 'psycopg', 'psycopg2', 'psycopg2cffi'}:
        engine = sa.create_engine(url, isolation_level='AUTOCOMMIT')
    else:
        engine = sa.create_engine(url)

    if dialect_name == 'sqlite' and database != ':memory:':
        # A SQLite database is just a file; deleting it drops the database.
        if database:
            os.remove(database)
    elif dialect_name == 'postgresql':
        with engine.begin() as conn:
            # Disconnect all users from the database we are dropping.
            version = conn.dialect.server_version_info
            # pg_stat_activity's backend-pid column was renamed from
            # 'procpid' to 'pid' in PostgreSQL 9.2.
            pid_column = (
                'pid' if (version >= (9, 2)) else 'procpid'
            )
            text = '''
            SELECT pg_terminate_backend(pg_stat_activity.{pid_column})
            FROM pg_stat_activity
            WHERE pg_stat_activity.datname = '{database}'
            AND {pid_column} <> pg_backend_pid();
            '''.format(pid_column=pid_column, database=database)
            conn.execute(sa.text(text))

            # Drop the database.
            text = f'DROP DATABASE {quote(conn, database)}'
            conn.execute(sa.text(text))
    else:
        with engine.begin() as conn:
            text = f'DROP DATABASE {quote(conn, database)}'
            conn.execute(sa.text(text))

    engine.dispose()
|
|
@ -0,0 +1,350 @@
|
|||
from collections import defaultdict
|
||||
from itertools import groupby
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.exc import NoInspectionAvailable
|
||||
from sqlalchemy.orm import object_session
|
||||
from sqlalchemy.schema import ForeignKeyConstraint, MetaData, Table
|
||||
|
||||
from ..query_chain import QueryChain
|
||||
from .database import has_index
|
||||
from .orm import _get_class_registry, get_column_key, get_mapper, get_tables
|
||||
|
||||
|
||||
def get_foreign_key_values(fk, obj):
    """Return a mapping from the local columns of *fk*'s constraint to the
    values they should hold when referencing *obj*.

    For each element of the constraint, the referenced column's value is
    read off *obj* — directly by column key when the attribute exists,
    otherwise through the mapper property mapped to that column (covers
    attributes whose name differs from the column name).
    """
    mapper = get_mapper(obj)
    return {
        fk.constraint.columns.values()[index]:
        getattr(obj, element.column.key)
        if hasattr(obj, element.column.key)
        else getattr(
            obj, mapper.get_property_by_column(element.column).key
        )
        for index, element in enumerate(fk.constraint.elements)
    }
|
||||
|
||||
|
||||
def group_foreign_keys(foreign_keys):
    """
    Return a groupby iterator that groups given foreign keys by table.

    :param foreign_keys: a sequence of foreign keys


    ::

        foreign_keys = get_referencing_foreign_keys(User)

        for table, fks in group_foreign_keys(foreign_keys):
            # do something
            pass


    .. seealso:: :func:`get_referencing_foreign_keys`

    .. versionadded: 0.26.1
    """
    # groupby only groups adjacent items, so order by table name first.
    ordered = sorted(foreign_keys, key=lambda fk: fk.constraint.table.name)
    return groupby(ordered, lambda fk: fk.constraint.table)
|
||||
|
||||
|
||||
def get_referencing_foreign_keys(mixed):
    """
    Returns referencing foreign keys for given Table object or declarative
    class.

    :param mixed:
        SA Table object or SA declarative class
    :return: set of ForeignKey elements referencing any of the tables

    ::

        get_referencing_foreign_keys(User)  # set([ForeignKey('user.id')])

        get_referencing_foreign_keys(User.__table__)


    This function also understands inheritance. This means it returns
    all foreign keys that reference any table in the class inheritance tree.

    Let's say you have three classes which use joined table inheritance,
    namely TextItem, Article and BlogPost with Article and BlogPost inheriting
    TextItem.

    ::

        # This will check all foreign keys that reference either article table
        # or textitem table.
        get_referencing_foreign_keys(Article)

    .. seealso:: :func:`get_tables`
    """
    if isinstance(mixed, sa.Table):
        tables = [mixed]
    else:
        # Declarative class: include all tables in its inheritance tree.
        tables = get_tables(mixed)

    referencing_foreign_keys = set()

    # Scan every other table in the shared MetaData for FK elements that
    # point at any of our tables.
    for table in mixed.metadata.tables.values():
        if table not in tables:
            for constraint in table.constraints:
                if isinstance(constraint, sa.sql.schema.ForeignKeyConstraint):
                    for fk in constraint.elements:
                        if any(fk.references(t) for t in tables):
                            referencing_foreign_keys.add(fk)
    return referencing_foreign_keys
|
||||
|
||||
|
||||
def merge_references(from_, to, foreign_keys=None):
    """
    Merge the references of an entity into another entity.

    Consider the following models::

        class User(self.Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.String(255))

            def __repr__(self):
                return 'User(name=%r)' % self.name

        class BlogPost(self.Base):
            __tablename__ = 'blog_post'
            id = sa.Column(sa.Integer, primary_key=True)
            title = sa.Column(sa.String(255))
            author_id = sa.Column(sa.Integer, sa.ForeignKey('user.id'))

            author = sa.orm.relationship(User)


    Now lets add some data::

        john = self.User(name='John')
        jack = self.User(name='Jack')
        post = self.BlogPost(title='Some title', author=john)
        post2 = self.BlogPost(title='Other title', author=jack)
        self.session.add_all([john, jack, post, post2])
        self.session.commit()


    If we wanted to merge all John's references to Jack it would be as easy as
    ::

        merge_references(john, jack)
        self.session.commit()

        post.author   # User(name='Jack')
        post2.author  # User(name='Jack')


    :param from_: an entity to merge into another entity
    :param to: an entity to merge another entity into
    :param foreign_keys: A sequence of foreign keys. By default this is None
        indicating all referencing foreign keys should be used.

    :raises TypeError: if the two entities map to different tables

    .. seealso: :func:`dependent_objects`

    .. versionadded: 0.26.1

    .. versionchanged: 0.40.0

        Removed possibility for old-style synchronize_session merging. Only
        SQL based merging supported for now.
    """
    if from_.__tablename__ != to.__tablename__:
        raise TypeError('The tables of given arguments do not match.')

    session = object_session(from_)
    # Bug fix: the ``foreign_keys`` argument used to be unconditionally
    # overwritten, so callers could never restrict the merge to a subset.
    # Only fall back to all referencing foreign keys when it is None
    # (same convention as dependent_objects).
    if foreign_keys is None:
        foreign_keys = get_referencing_foreign_keys(from_)

    for fk in foreign_keys:
        old_values = get_foreign_key_values(fk, from_)
        new_values = get_foreign_key_values(fk, to)
        # Match rows still pointing at ``from_`` ...
        criteria = (
            getattr(fk.constraint.table.c, key.key) == value
            for key, value in old_values.items()
        )
        # ... and repoint them at ``to`` with a single UPDATE.
        query = (
            fk.constraint.table.update()
            .where(sa.and_(*criteria))
            .values(
                {key.key: value for key, value in new_values.items()}
            )
        )
        session.execute(query)
|
||||
|
||||
|
||||
def dependent_objects(obj, foreign_keys=None):
    """
    Return a :class:`~sqlalchemy_utils.query_chain.QueryChain` that iterates
    through all dependent objects for given SQLAlchemy object.

    Consider a User object is referenced in various articles and also in
    various orders. Getting all these dependent objects is as easy as::

        from sqlalchemy_utils import dependent_objects


        dependent_objects(user)


    If you expect an object to have lots of dependent_objects it might be good
    to limit the results::


        dependent_objects(user).limit(5)



    The common use case is checking for all restrict dependent objects before
    deleting parent object and inform the user if there are dependent objects
    with ondelete='RESTRICT' foreign keys. If this kind of checking is not used
    it will lead to nasty IntegrityErrors being raised.

    In the following example we delete given user if it doesn't have any
    foreign key restricted dependent objects::


        from sqlalchemy_utils import get_referencing_foreign_keys


        user = session.query(User).get(some_user_id)


        deps = list(
            dependent_objects(
                user,
                (
                    fk for fk in get_referencing_foreign_keys(User)
                    # On most databases RESTRICT is the default mode hence we
                    # check for None values also
                    if fk.ondelete == 'RESTRICT' or fk.ondelete is None
                )
            ).limit(5)
        )

        if deps:
            # Do something to inform the user
            pass
        else:
            session.delete(user)


    :param obj: SQLAlchemy declarative model object
    :param foreign_keys:
        A sequence of foreign keys to use for searching the dependent_objects
        for given object. By default this is None, indicating that all foreign
        keys referencing the object will be used.

    .. note::
        This function does not support exotic mappers that use multiple tables

    .. seealso:: :func:`get_referencing_foreign_keys`
    .. seealso:: :func:`merge_references`

    .. versionadded: 0.26.0
    """
    if foreign_keys is None:
        foreign_keys = get_referencing_foreign_keys(obj)

    session = object_session(obj)

    chain = QueryChain([])
    # Every mapped class registered alongside obj's class is a candidate
    # owner of a referencing table.
    classes = _get_class_registry(obj.__class__)

    for table, keys in group_foreign_keys(foreign_keys):
        # keys is a groupby iterator; materialize so it can be reused.
        keys = list(keys)
        for class_ in classes.values():
            try:
                mapper = sa.inspect(class_)
            except NoInspectionAvailable:
                # Registry entries that are not mapped classes are skipped.
                continue
            parent_mapper = mapper.inherits
            # Pick only the class that directly owns this table — exclude
            # subclasses whose parent already maps the same table (joined
            # table inheritance).
            if (
                table in mapper.tables and
                not (parent_mapper and table in parent_mapper.tables)
            ):
                query = session.query(class_).filter(
                    sa.or_(*_get_criteria(keys, class_, obj))
                )
                chain.queries.append(query)
    return chain
|
||||
|
||||
|
||||
def _get_criteria(keys, class_, obj):
    """Build a list of AND-criteria, one per distinct foreign key
    constraint, matching rows of *class_* that reference *obj*.

    Each criterion equates every local column of the constraint (as an
    attribute of *class_*) with the corresponding referenced value read
    off *obj* via its mapper property.
    """
    criteria = []
    # Several keys may belong to the same composite constraint; emit one
    # criterion per constraint, not per key.
    visited_constraints = []
    for key in keys:
        if key.constraint in visited_constraints:
            continue
        visited_constraints.append(key.constraint)

        subcriteria = []
        for index, column in enumerate(key.constraint.columns):
            # constraint.elements is parallel to constraint.columns:
            # element[index].column is the referenced (remote) column.
            foreign_column = (
                key.constraint.elements[index].column
            )
            subcriteria.append(
                getattr(class_, get_column_key(class_, column)) ==
                getattr(
                    obj,
                    sa.inspect(type(obj))
                    .get_property_by_column(
                        foreign_column
                    ).key
                )
            )
        criteria.append(sa.and_(*subcriteria))
    return criteria
|
||||
|
||||
|
||||
def non_indexed_foreign_keys(metadata, engine=None):
    """
    Finds all non indexed foreign keys from all tables of given MetaData.

    Very useful for optimizing postgresql database and finding out which
    foreign keys need indexes.

    :param metadata: MetaData object to inspect tables from
    :param engine: engine to reflect against when the MetaData carries no
        bind (required on SQLAlchemy >= 1.4 where MetaData.bind is gone)
    :return: dict mapping table name to a list of unindexed
        ForeignKeyConstraint objects
    """
    # Reflect into a fresh MetaData so index information comes from the
    # live database rather than the (possibly incomplete) given metadata.
    reflected_metadata = MetaData()

    bind = getattr(metadata, 'bind', None)
    if bind is None and engine is None:
        raise Exception(
            'Either pass a metadata object with bind or '
            'pass engine as a second parameter'
        )

    constraints = defaultdict(list)

    for table_name in metadata.tables.keys():
        table = Table(
            table_name,
            reflected_metadata,
            autoload_with=bind or engine
        )

        for constraint in table.constraints:
            if not isinstance(constraint, ForeignKeyConstraint):
                continue

            if not has_index(constraint):
                constraints[table.name].append(constraint)

    return dict(constraints)
|
||||
|
||||
|
||||
def get_fk_constraint_for_columns(table, *columns):
    """Return the constraint of *table* whose columns are exactly the given
    *columns* (same order), or None if no constraint matches."""
    wanted = list(columns)
    for candidate in table.constraints:
        if list(candidate.columns.values()) == wanted:
            return candidate
    return None
|
|
@ -0,0 +1,112 @@
|
|||
import contextlib
|
||||
import datetime
|
||||
import inspect
|
||||
import io
|
||||
import re
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
def create_mock_engine(bind, stream=None):
    """Create a mock SQLAlchemy engine from the passed engine or bind URL.

    :param bind: A SQLAlchemy engine or bind URL to mock.
    :param stream: Render all DDL operations to the stream.
    :return: a mock engine whose executor renders statements to *stream*
        (or discards them when *stream* is None)
    """

    if not isinstance(bind, str):
        bind_url = str(bind.url)

    else:
        bind_url = bind

    if stream is not None:

        def dump(sql, *args, **kwargs):

            # Subclass the statement's own compiler so bind parameters are
            # rendered inline as literals instead of placeholders.
            class Compiler(type(sql._compiler(engine.dialect))):

                def visit_bindparam(self, bindparam, *args, **kwargs):
                    return self.render_literal_value(
                        bindparam.value, bindparam.type)

                def render_literal_value(self, value, type_):
                    if isinstance(value, int):
                        return str(value)

                    elif isinstance(value, (datetime.date, datetime.datetime)):
                        return "'%s'" % value

                    return super().render_literal_value(
                        value, type_)

            text = str(Compiler(engine.dialect, sql).process(sql))
            # Collapse blank lines and surrounding whitespace for tidy output.
            text = re.sub(r'\n+', '\n', text)
            text = text.strip('\n').strip()

            stream.write('\n%s;' % text)

    else:
        # No stream: swallow all statements.
        def dump(*args, **kw):
            return None

    try:
        engine = sa.create_mock_engine(bind_url, executor=dump)
    except AttributeError:  # SQLAlchemy <1.4
        engine = sa.create_engine(bind_url, strategy='mock', executor=dump)
    return engine
|
||||
|
||||
|
||||
@contextlib.contextmanager
def mock_engine(engine, stream=None):
    """Mocks out the engine specified in the passed bind expression.

    Note this function is meant for convenience and protected usage. Do NOT
    blindly pass user input to this function as it uses exec.

    :param engine: A python expression that represents the engine to mock.
    :param stream: Render all DDL operations to the stream.
    """

    # Create a stream if not present.
    if stream is None:
        stream = io.StringIO()

    # Navigate the stack and find the calling frame that allows the
    # expression to execute.
    for frame in inspect.stack()[1:]:
        try:
            frame = frame[0]
            # A successful exec means this frame's scope can resolve the
            # engine expression; remember the resolved object.
            expression = '__target = %s' % engine
            exec(expression, frame.f_globals, frame.f_locals)
            target = frame.f_locals['__target']
            break
        except Exception:
            # This frame cannot evaluate the expression -- try the next.
            pass
    else:
        raise ValueError('Not a valid python expression', engine)

    # Evaluate the expression and get the target engine.
    frame.f_locals['__mock'] = create_mock_engine(target, stream)

    # Replace the target with our mock.
    exec('%s = __mock' % engine, frame.f_globals, frame.f_locals)

    # Give control back.
    yield stream

    # Put the target engine back.
    frame.f_locals['__target'] = target
    exec('%s = __target' % engine, frame.f_globals, frame.f_locals)
    exec('del __target', frame.f_globals, frame.f_locals)
    exec('del __mock', frame.f_globals, frame.f_locals)
|
|
@ -0,0 +1,904 @@
|
|||
from collections import OrderedDict
|
||||
from functools import partial
|
||||
from inspect import isclass
|
||||
from operator import attrgetter
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.engine.interfaces import Dialect
|
||||
from sqlalchemy.ext.hybrid import hybrid_property
|
||||
from sqlalchemy.orm import ColumnProperty, mapperlib, RelationshipProperty
|
||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||
from sqlalchemy.orm.exc import UnmappedInstanceError
|
||||
|
||||
try:
|
||||
from sqlalchemy.orm.context import _ColumnEntity, _MapperEntity
|
||||
except ImportError: # SQLAlchemy <1.4
|
||||
from sqlalchemy.orm.query import _ColumnEntity, _MapperEntity
|
||||
|
||||
from sqlalchemy.orm.session import object_session
|
||||
from sqlalchemy.orm.util import AliasedInsp
|
||||
|
||||
from ..utils import is_sequence
|
||||
|
||||
|
||||
def get_class_by_table(base, table, data=None):
    """
    Return declarative class associated with given table. If no class is found
    this function returns `None`. If multiple classes were found (polymorphic
    cases) additional `data` parameter can be given to hint which class
    to return.

    ::

        class User(Base):
            __tablename__ = 'entity'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.String)


        get_class_by_table(Base, User.__table__)  # User class


    This function also supports models using single table inheritance.
    Additional data parameter should be provided in these cases.

    ::

        class Entity(Base):
            __tablename__ = 'entity'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.String)
            type = sa.Column(sa.String)
            __mapper_args__ = {
                'polymorphic_on': type,
                'polymorphic_identity': 'entity'
            }

        class User(Entity):
            __mapper_args__ = {
                'polymorphic_identity': 'user'
            }


        # Entity class
        get_class_by_table(Base, Entity.__table__, {'type': 'entity'})

        # User class
        get_class_by_table(Base, Entity.__table__, {'type': 'user'})


    :param base: Declarative model base
    :param table: SQLAlchemy Table object
    :param data: Data row to determine the class in polymorphic scenarios
    :return: Declarative class or None.
    """
    # All registered declarative classes whose __table__ is exactly this
    # table (identity check, not equality).
    found_classes = {
        c for c in _get_class_registry(base).values()
        if hasattr(c, '__table__') and c.__table__ is table
    }
    if len(found_classes) > 1:
        # Single-table inheritance: several classes share one table, so a
        # data row is required to pick the class by polymorphic identity.
        if not data:
            raise ValueError(
                "Multiple declarative classes found for table '{}'. "
                "Please provide data parameter for this function to be able "
                "to determine polymorphic scenarios.".format(
                    table.name
                )
            )
        else:
            for cls in found_classes:
                mapper = sa.inspect(cls)
                polymorphic_on = mapper.polymorphic_on.name
                if polymorphic_on in data:
                    if data[polymorphic_on] == mapper.polymorphic_identity:
                        return cls
            # No candidate matched the row's discriminator value.
            raise ValueError(
                "Multiple declarative classes found for table '{}'. Given "
                "data row does not match any polymorphic identity of the "
                "found classes.".format(
                    table.name
                )
            )
    elif found_classes:
        return found_classes.pop()
    return None
|
||||
|
||||
|
||||
def get_type(expr):
    """Return the type associated with a SQLAlchemy construct.

    For column-like constructs (Column, InstrumentedAttribute,
    ColumnProperty) this is the column type; for relationships it is the
    mapped class on the other side of the relationship.

    ::

        get_type(User.__table__.c.name)  # sa.String()
        get_type(User.name)              # sa.String()
        get_type(Article.author)         # User

    :param expr:
        SQLAlchemy Column, InstrumentedAttribute, ColumnProperty,
        RelationshipProperty or similar construct.
    :raises TypeError: when no type can be determined.

    .. versionadded: 0.30.9
    """
    # Columns and most column-like constructs expose ``type`` directly.
    if hasattr(expr, 'type'):
        return expr.type

    # Unwrap an instrumented attribute down to its mapper property.
    if isinstance(expr, InstrumentedAttribute):
        expr = expr.property

    if isinstance(expr, ColumnProperty):
        return expr.columns[0].type
    if isinstance(expr, RelationshipProperty):
        return expr.mapper.class_
    raise TypeError("Couldn't inspect type.")
|
||||
|
||||
|
||||
def cast_if(expression, type_):
    """Produce a CAST expression only when *expression* is not already of
    the given type.

    Works with SQL expressions (compares the expression's SQLAlchemy type)
    as well as plain Python scalars (compares against the target type's
    ``python_type``).

    ::

        cast_if(User.id, sa.Integer)    # "user".id
        cast_if(User.id, sa.String)     # CAST("user".id AS TEXT)
        cast_if(1, sa.Integer)          # 1
        cast_if(1, sa.String)           # CAST(1 AS TEXT)

    :param expression:
        A SQL expression, such as a ColumnElement expression or a Python
        string which will be coerced into a bound literal value.
    :param type_:
        A TypeEngine class or instance indicating the type to which the
        CAST should apply.

    .. versionadded: 0.30.14
    """
    try:
        current_type = get_type(expression)
    except TypeError:
        # Plain scalar: compare the value itself against the Python-level
        # type of the target TypeEngine.
        current_type = expression
        target_type = type_().python_type
    else:
        target_type = type_

    if isinstance(current_type, target_type):
        return expression
    return sa.cast(expression, type_)
|
||||
|
||||
|
||||
def get_column_key(model, column):
    """
    Return the key for given column in given model.

    :param model: SQLAlchemy declarative model object

    ::

        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column('_name', sa.String)


        get_column_key(User, User.__table__.c._name)  # 'name'

    .. versionadded: 0.26.5

    .. versionchanged: 0.27.11
        Throws UnmappedColumnError instead of ValueError when no property was
        found for given column. This is consistent with how SQLAlchemy works.
    """
    mapper = sa.inspect(model)
    try:
        return mapper.get_property_by_column(column).key
    except sa.orm.exc.UnmappedColumnError:
        # Fallback: scan the mapper's columns for one with the same name
        # on the same table object.
        for key, c in mapper.columns.items():
            if c.name == column.name and c.table is column.table:
                return key
        raise sa.orm.exc.UnmappedColumnError(
            'No column %s is configured on mapper %s...' %
            (column, mapper)
        )
|
||||
|
||||
|
||||
def get_mapper(mixed):
    """
    Return related SQLAlchemy Mapper for given SQLAlchemy object.

    :param mixed: SQLAlchemy Table / Alias / Mapper / declarative model object

    ::

        from sqlalchemy_utils import get_mapper


        get_mapper(User)

        get_mapper(User())

        get_mapper(User.__table__)

        get_mapper(User.__mapper__)

        get_mapper(sa.orm.aliased(User))

        get_mapper(sa.orm.aliased(User.__table__))


    Raises:
        ValueError: if multiple mappers were found for given argument

    .. versionadded: 0.26.1
    """
    # Unwrap query-entity wrappers down to the underlying expression
    # before the main type dispatch below.
    if isinstance(mixed, _MapperEntity):
        mixed = mixed.expr
    elif isinstance(mixed, sa.Column):
        mixed = mixed.table
    elif isinstance(mixed, _ColumnEntity):
        mixed = mixed.expr

    if isinstance(mixed, sa.orm.Mapper):
        return mixed
    if isinstance(mixed, sa.orm.util.AliasedClass):
        return sa.inspect(mixed).mapper
    if isinstance(mixed, sa.sql.selectable.Alias):
        mixed = mixed.element
    if isinstance(mixed, AliasedInsp):
        return mixed.mapper
    if isinstance(mixed, sa.orm.attributes.InstrumentedAttribute):
        mixed = mixed.class_
    if isinstance(mixed, sa.Table):
        # A bare Table: collect every known mapper and pick the one that
        # maps this table.
        if hasattr(mapperlib, '_all_registries'):
            all_mappers = set()
            for mapper_registry in mapperlib._all_registries():
                all_mappers.update(mapper_registry.mappers)
        else:  # SQLAlchemy <1.4
            all_mappers = mapperlib._mapper_registry
        mappers = [
            mapper for mapper in all_mappers
            if mixed in mapper.tables
        ]
        if len(mappers) > 1:
            raise ValueError(
                "Multiple mappers found for table '%s'." % mixed.name
            )
        elif not mappers:
            raise ValueError(
                "Could not get mapper for table '%s'." % mixed.name
            )
        else:
            return mappers[0]
    # Fall back to inspecting the class (or the instance's class).
    if not isclass(mixed):
        mixed = type(mixed)
    return sa.inspect(mixed)
|
||||
|
||||
|
||||
def get_bind(obj):
    """Return the bind for given SQLAlchemy Engine / Connection /
    declarative model object.

    ::

        get_bind(session)  # Connection object

        get_bind(user)

    :param obj: SQLAlchemy Engine / Connection / declarative model object
    :raises TypeError: when no executable bind can be derived from *obj*
    """
    if hasattr(obj, 'bind'):
        connectable = obj.bind
    else:
        # Model instances resolve through their owning session; anything
        # unmapped is assumed to be a connectable itself.
        try:
            connectable = object_session(obj).bind
        except UnmappedInstanceError:
            connectable = obj

    if not hasattr(connectable, 'execute'):
        raise TypeError(
            'This method accepts only Session, Engine, Connection and '
            'declarative model objects.'
        )
    return connectable
|
||||
|
||||
|
||||
def get_primary_keys(mixed):
    """Return an OrderedDict of all primary keys for given Table object,
    declarative class or declarative class instance.

    ::

        get_primary_keys(User)
        get_primary_keys(User())
        get_primary_keys(User.__table__)
        get_primary_keys(User.__mapper__)
        get_primary_keys(sa.orm.aliased(User))
        get_primary_keys(sa.orm.aliased(User.__table__))

    :param mixed:
        SA Table object, SA declarative class or SA declarative class instance

    .. versionchanged: 0.25.3
        Made the function return an ordered dictionary instead of generator.
        This change was made to support primary key aliases.

        Renamed this function to 'get_primary_keys', formerly 'primary_keys'

    .. seealso:: :func:`get_columns`
    """
    return OrderedDict(
        (name, col)
        for name, col in get_columns(mixed).items()
        if col.primary_key
    )
|
||||
|
||||
|
||||
def get_tables(mixed):
    """
    Return a set of tables associated with given SQLAlchemy object.

    Let's say we have three classes which use joined table inheritance
    TextItem, Article and BlogPost. Article and BlogPost inherit TextItem.

    ::

        get_tables(Article)  # set([Table('article', ...), Table('text_item')])

        get_tables(Article())

        get_tables(Article.__mapper__)


    If the TextItem entity is using with_polymorphic='*' then this function
    returns all child tables (article and blog_post) as well.

    ::


        get_tables(TextItem)  # set([Table('text_item', ...)], ...])


    .. versionadded: 0.26.0

    :param mixed:
        SQLAlchemy Mapper, Declarative class, Column, InstrumentedAttribute or
        a SA Alias object wrapping any of these objects.
    """
    # Fast paths for constructs that map directly to concrete tables.
    if isinstance(mixed, sa.Table):
        return [mixed]
    elif isinstance(mixed, sa.Column):
        return [mixed.table]
    elif isinstance(mixed, sa.orm.attributes.InstrumentedAttribute):
        return mixed.parent.tables
    elif isinstance(mixed, _ColumnEntity):
        mixed = mixed.expr

    mapper = get_mapper(mixed)

    # A with_polymorphic mapper spans the tables of all participating
    # sub-mappers; otherwise just use the mapper's own tables.
    polymorphic_mappers = get_polymorphic_mappers(mapper)
    if polymorphic_mappers:
        tables = sum((m.tables for m in polymorphic_mappers), [])
    else:
        tables = mapper.tables
    return tables
|
||||
|
||||
|
||||
def get_columns(mixed):
    """
    Return a collection of all Column objects for given SQLAlchemy
    object.

    The type of the collection depends on the type of the object to return the
    columns from.

    ::

        get_columns(User)

        get_columns(User())

        get_columns(User.__table__)

        get_columns(User.__mapper__)

        get_columns(sa.orm.aliased(User))

        get_columns(sa.orm.aliased(User.__table__))


    :param mixed:
        SA Table object, SA Mapper, SA declarative class, SA declarative class
        instance or an alias of any of these objects
    """
    # Dispatch on the construct type; each branch returns the native
    # column collection of that construct.
    if isinstance(mixed, sa.sql.selectable.Selectable):
        try:
            return mixed.selected_columns
        except AttributeError:  # SQLAlchemy <1.4
            return mixed.c
    if isinstance(mixed, sa.orm.util.AliasedClass):
        return sa.inspect(mixed).mapper.columns
    if isinstance(mixed, sa.orm.Mapper):
        return mixed.columns
    if isinstance(mixed, InstrumentedAttribute):
        return mixed.property.columns
    if isinstance(mixed, ColumnProperty):
        return mixed.columns
    if isinstance(mixed, sa.Column):
        return [mixed]
    # Fall back to inspecting the declarative class (or instance's class).
    if not isclass(mixed):
        mixed = mixed.__class__
    return sa.inspect(mixed).columns
|
||||
|
||||
|
||||
def table_name(obj):
    """Return table name of given target, declarative class or the
    table name where the declarative attribute is bound to.

    Returns ``None`` when neither ``__tablename__`` nor ``__table__`` is
    available.
    """
    target = getattr(obj, 'class_', obj)

    for getter in (attrgetter('__tablename__'), attrgetter('__table__.name')):
        try:
            return getter(target)
        except AttributeError:
            pass
|
||||
|
||||
|
||||
def getattrs(obj, attrs):
    """Yield each attribute named in *attrs* read from *obj*."""
    return (getattr(obj, name) for name in attrs)
|
||||
|
||||
|
||||
def quote(mixed, ident):
    """Conditionally quote an identifier.

    The identifier is quoted only when the dialect requires it (e.g. for
    reserved words such as ``order``).

    ::

        quote(engine, 'order')                  # '"order"'
        quote(engine, 'some_other_identifier')  # 'some_other_identifier'

    :param mixed: SQLAlchemy Session / Connection / Engine / Dialect object.
    :param ident: identifier to conditionally quote
    """
    dialect = mixed if isinstance(mixed, Dialect) else get_bind(mixed).dialect
    preparer = dialect.preparer(dialect)
    return preparer.quote(ident)
|
||||
|
||||
|
||||
def _get_query_compile_state(query):
|
||||
if hasattr(query, '_compile_state'):
|
||||
return query._compile_state()
|
||||
else: # SQLAlchemy <1.4
|
||||
return query
|
||||
|
||||
|
||||
def get_polymorphic_mappers(mixed):
    """Return the polymorphic mappers participating in *mixed*."""
    # An AliasedInsp carries its polymorphic mappers directly; a plain
    # mapper enumerates them through its polymorphic map.
    if not isinstance(mixed, AliasedInsp):
        return mixed.polymorphic_map.values()
    return mixed.with_polymorphic_mappers
|
||||
|
||||
|
||||
def get_descriptor(entity, attr):
    """Return the descriptor (column, instrumented attribute, hybrid
    property, ...) named *attr* on *entity*, taking aliased classes and
    polymorphic inheritance into account.

    Returns ``None`` when no matching descriptor is found.
    """
    mapper = sa.inspect(entity)

    for key, descriptor in get_all_descriptors(mapper).items():
        if attr == key:
            # Mapper properties expose ``property``; hybrids etc. do not.
            prop = (
                descriptor.property
                if hasattr(descriptor, 'property')
                else None
            )
            if isinstance(prop, ColumnProperty):
                if isinstance(entity, sa.orm.util.AliasedClass):
                    # For aliases, return the matching column of the
                    # aliased selectable rather than the class attribute.
                    for c in mapper.selectable.c:
                        if c.key == attr:
                            return c
                else:
                    # If the property belongs to a class that uses
                    # polymorphic inheritance we have to take into account
                    # situations where the attribute exists in child class
                    # but not in parent class.
                    return getattr(prop.parent.class_, attr)
            else:
                # Handle synonyms, relationship properties and hybrid
                # properties

                if isinstance(entity, sa.orm.util.AliasedClass):
                    return getattr(entity, attr)
                try:
                    return getattr(mapper.class_, attr)
                except AttributeError:
                    pass
|
||||
|
||||
|
||||
def get_all_descriptors(expr):
    """Return a mapping of attribute name to ORM descriptor for *expr*,
    merging in descriptors from polymorphic sub-mappers when present.

    Plain selectables simply yield their column collection.
    """
    if isinstance(expr, sa.sql.selectable.Selectable):
        return expr.c
    insp = sa.inspect(expr)
    try:
        polymorphic_mappers = get_polymorphic_mappers(insp)
    except sa.exc.NoInspectionAvailable:
        return get_mapper(expr).all_orm_descriptors
    else:
        attrs = dict(get_mapper(expr).all_orm_descriptors)
        # Descriptors declared only on sub-mappers are merged in without
        # overriding those of the base mapper.
        for submapper in polymorphic_mappers:
            for key, descriptor in submapper.all_orm_descriptors.items():
                if key not in attrs:
                    attrs[key] = descriptor
        return attrs
|
||||
|
||||
|
||||
def get_hybrid_properties(model):
    """Return a dictionary of hybrid property keys and hybrid properties
    for the given SQLAlchemy declarative model / mapper.

    ::

        from sqlalchemy_utils import get_hybrid_properties


        get_hybrid_properties(Category).keys()  # ['lowercase_name']


    Aliased classes are supported as well:

    ::

        get_hybrid_properties(
            sa.orm.aliased(Category)
        ).keys()  # ['lowercase_name']


    .. versionchanged: 0.26.7
        This function now returns a dictionary instead of generator

    .. versionchanged: 0.30.15
        Added support for aliased classes

    :param model: SQLAlchemy declarative model or mapper
    """
    descriptors = get_mapper(model).all_orm_descriptors
    return {
        name: descriptor
        for name, descriptor in descriptors.items()
        if isinstance(descriptor, hybrid_property)
    }
|
||||
|
||||
|
||||
def get_declarative_base(model):
    """Return the declarative base for given model class.

    Walks up the inheritance chain to the topmost class that still carries
    a ``metadata`` attribute.

    :param model: SQLAlchemy declarative model
    """
    for parent in model.__bases__:
        # hasattr swallows exactly AttributeError, matching the original
        # try/except around ``parent.metadata``.
        if hasattr(parent, 'metadata'):
            return get_declarative_base(parent)
    return model
|
||||
|
||||
|
||||
def getdotattr(obj_or_class, dot_path, condition=None):
    """Allow dot-notated strings to be passed to `getattr`.

    ::

        getdotattr(SubSection, 'section.document')

        getdotattr(subsection, 'section.document')

    :param obj_or_class: Any object or class
    :param dot_path: Attribute path with dot mark as separator
    :param condition: Optional predicate used to filter the final value(s)
    """
    current = obj_or_class

    for segment in str(dot_path).split('.'):
        getter = attrgetter(segment)

        if is_sequence(current):
            # Fan out over sequences, flattening nested sequences.
            gathered = []
            for element in current:
                value = getter(element)
                if is_sequence(value):
                    gathered.extend(value)
                else:
                    gathered.append(value)
            current = gathered
        elif isinstance(current, InstrumentedAttribute):
            # Follow relationship attributes through to the target class.
            current = getter(current.property.mapper.class_)
        elif current is None:
            return None
        else:
            current = getter(current)

    if condition is not None:
        if is_sequence(current):
            current = [v for v in current if condition(v)]
        elif not condition(current):
            return None

    return current
|
||||
|
||||
|
||||
def is_deleted(obj):
    """Return ``True`` when *obj* is marked for deletion in its session."""
    session = sa.orm.object_session(obj)
    return obj in session.deleted
|
||||
|
||||
|
||||
def has_changes(obj, attrs=None, exclude=None):
    """Check whether given attributes of a declarative model object have
    unflushed changes in the current session.

    Without ``attrs`` this checks whether the object has any modifications
    at all. ``exclude`` restricts that check to attributes other than the
    excluded ones.

    ::

        from sqlalchemy_utils import has_changes


        user = User()

        has_changes(user, 'name')  # False

        user.name = 'someone'

        has_changes(user, 'name')          # True
        has_changes(user)                  # True
        has_changes(user, ['name', 'age']) # True
        has_changes(user, exclude=['name'])  # False

    .. versionchanged: 0.26.6
        Added support for multiple attributes and exclude parameter.

    :param obj: SQLAlchemy declarative model object
    :param attrs: Names of the attributes
    :param exclude: Names of the attributes to exclude
    """
    if attrs:
        # Single attribute name: consult its history directly.
        if isinstance(attrs, str):
            history = sa.inspect(obj).attrs.get(attrs).history
            return history.has_changes()
        # Several names: any one of them having changes is enough.
        return any(has_changes(obj, name) for name in attrs)

    excluded = [] if exclude is None else exclude
    return any(
        attribute.history.has_changes()
        for name, attribute in sa.inspect(obj).attrs.items()
        if name not in excluded
    )
|
||||
|
||||
|
||||
def is_loaded(obj, prop):
    """Return whether or not given property of given object has been loaded.

    Deferred columns and not-yet-loaded relationships report as unloaded.

    ::

        article = session.query(Article).get(5)

        # name gets loaded since its not a deferred property
        assert is_loaded(article, 'name')

        # content has not yet been loaded since its a deferred property
        assert not is_loaded(article, 'content')

    .. versionadded: 0.27.8

    :param obj: SQLAlchemy declarative model object
    :param prop: Name of the property or InstrumentedAttribute
    """
    unloaded_props = sa.inspect(obj).unloaded
    return prop not in unloaded_props
|
||||
|
||||
|
||||
def identity(obj_or_class):
    """Return the identity of given sqlalchemy declarative model class or
    instance as a tuple.

    Unlike ``inspect(obj).identity`` this also returns the identity for
    transient objects whose primary key has been assigned but not yet
    committed. For classes it returns the identity attributes themselves.

    ::

        user = User(name='John Matrix')
        session.add(user)
        identity(user)  # None

        session.flush()  # User now has id but is still in transient state
        identity(user)  # (1,)

        identity(User)  # (User.id, )

    .. versionadded: 0.21.0

    :param obj_or_class: SQLAlchemy declarative model object or class
    """
    primary_key_names = get_primary_keys(obj_or_class).keys()
    return tuple(
        getattr(obj_or_class, name) for name in primary_key_names
    )
|
||||
|
||||
|
||||
def naturally_equivalent(obj, obj2):
    """Return whether or not two given SQLAlchemy declarative instances are
    naturally equivalent (all their non primary key properties are
    equivalent).

    ::

        from sqlalchemy_utils import naturally_equivalent


        user = User(name='someone')
        user2 = User(name='someone')

        user == user2                     # False
        naturally_equivalent(user, user2)  # True

    :param obj: SQLAlchemy declarative model object
    :param obj2: SQLAlchemy declarative model object to compare with `obj`
    """
    columns = sa.inspect(obj.__class__).columns
    return all(
        getattr(obj, name) == getattr(obj2, name)
        for name, column in columns.items()
        if not column.primary_key
    )
|
||||
|
||||
|
||||
def _get_class_registry(class_):
|
||||
try:
|
||||
return class_.registry._class_registry
|
||||
except AttributeError: # SQLAlchemy <1.4
|
||||
return class_._decl_class_registry
|
|
@ -0,0 +1,75 @@
|
|||
import inspect
|
||||
import io
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from .mock import create_mock_engine
|
||||
from .orm import _get_query_compile_state
|
||||
|
||||
|
||||
def render_expression(expression, bind, stream=None):
    """Generate a SQL expression from the passed python expression.

    Only the global variable, `engine`, is available for use in the
    expression. Additional local variables may be passed in the context
    parameter.

    Note this function is meant for convenience and protected usage. Do NOT
    blindly pass user input to this function as it uses exec.

    :param bind: A SQLAlchemy engine or bind URL.
    :param stream: Render all DDL operations to the stream.
    """

    # Create a stream if not present.
    if stream is None:
        stream = io.StringIO()

    engine = create_mock_engine(bind, stream)

    # Navigate the stack and find the calling frame that allows the
    # expression to execute.
    for frame in inspect.stack()[1:]:
        try:
            frame = frame[0]
            # Execute against a *copy* of the frame locals so the caller's
            # namespace is not polluted; inject the mock engine as `engine`.
            local = dict(frame.f_locals)
            local['engine'] = engine
            exec(expression, frame.f_globals, local)
            break
        except Exception:
            # This frame cannot execute the expression -- try the next.
            pass
    else:
        raise ValueError('Not a valid python expression', engine)

    return stream
|
||||
|
||||
|
||||
def render_statement(statement, bind=None):
    """
    Generate an SQL expression string with bound parameters rendered inline
    for the given SQLAlchemy statement.

    :param statement: SQLAlchemy Query object.
    :param bind:
        Optional SQLAlchemy bind, if None uses the bind of the given query
        object.
    """

    if isinstance(statement, sa.orm.query.Query):
        if bind is None:
            # Resolve the bind from the query's session via its primary
            # mapped entity.
            bind = statement.session.get_bind(
                _get_query_compile_state(statement)._mapper_zero()
            )

        statement = statement.statement

    elif bind is None:
        bind = statement.bind

    # Execute against a mock engine that renders statements into the
    # stream instead of hitting a database.
    stream = io.StringIO()
    engine = create_mock_engine(bind.engine, stream=stream)
    engine.execute(statement)

    return stream.getvalue()
|
|
@ -0,0 +1,74 @@
|
|||
import sqlalchemy as sa
|
||||
|
||||
from .database import has_unique_index
|
||||
from .orm import _get_query_compile_state, get_tables
|
||||
|
||||
|
||||
def make_order_by_deterministic(query):
    """
    Make query order by deterministic (if it isn't already). Order by is
    considered deterministic if it contains column that is unique index (
    either it is a primary key or has a unique index). Many times it is design
    flaw to order by queries in nondeterministic manner.

    Consider a User model with three fields: id (primary key), favorite color
    and email (unique).::


        from sqlalchemy_utils import make_order_by_deterministic


        query = session.query(User).order_by(User.favorite_color)

        query = make_order_by_deterministic(query)
        print query  # 'SELECT ... ORDER BY "user".favorite_color, "user".id'


        query = session.query(User).order_by(User.email)

        query = make_order_by_deterministic(query)
        print query  # 'SELECT ... ORDER BY "user".email'


        query = session.query(User).order_by(User.id)

        query = make_order_by_deterministic(query)
        print query  # 'SELECT ... ORDER BY "user".id'


    .. versionadded: 0.27.1
    """
    order_by_func = sa.asc

    try:
        order_by_clauses = query._order_by_clauses
    except AttributeError:  # SQLAlchemy <1.4
        order_by_clauses = query._order_by
    if not order_by_clauses:
        column = None
    else:
        # Only the first ORDER BY clause is examined; unwrap label
        # references and asc()/desc() wrappers to reach the column and
        # remember the direction so the tie-breaker matches it.
        order_by = order_by_clauses[0]
        if isinstance(order_by, sa.sql.elements._label_reference):
            order_by = order_by.element
        if isinstance(order_by, sa.sql.expression.UnaryExpression):
            if order_by.modifier == sa.sql.operators.desc_op:
                order_by_func = sa.desc
            else:
                order_by_func = sa.asc
            column = list(order_by.get_children())[0]
        else:
            column = order_by

    # Skip queries that are ordered by an already deterministic column
    if isinstance(column, sa.Column):
        try:
            if has_unique_index(column):
                return query
        except TypeError:
            pass

    # Append the primary key column(s) of the query's base table as a
    # deterministic tie-breaker, using the same sort direction.
    base_table = get_tables(_get_query_compile_state(query)._entities[0])[0]
    query = query.order_by(
        *(order_by_func(c) for c in base_table.c if c.primary_key)
    )
    return query
|
|
@ -0,0 +1,185 @@
|
|||
from collections.abc import Iterable
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.ext.hybrid import hybrid_property
|
||||
from sqlalchemy.orm import attributes, class_mapper, ColumnProperty
|
||||
from sqlalchemy.orm.interfaces import MapperProperty, PropComparator
|
||||
from sqlalchemy.orm.session import _state_session
|
||||
from sqlalchemy.util import set_creation_order
|
||||
|
||||
from .exceptions import ImproperlyConfigured
|
||||
from .functions import identity
|
||||
from .functions.orm import _get_class_registry
|
||||
|
||||
|
||||
class GenericAttributeImpl(attributes.ScalarAttributeImpl):
    """Attribute implementation backing :func:`generic_relationship`.

    The referred object is not stored as a foreign key; instead the parent
    row holds a discriminator (the target class name) plus the target's
    identity columns. ``get`` lazily loads the target from those values and
    ``set`` writes them back.
    """

    def get(self, state, dict_, passive=attributes.PASSIVE_OFF):
        """Return the referred object, loading it lazily when needed.

        Returns ``None`` when the state has no session, the discriminator
        does not resolve to a known class, or no row matches the stored id.
        """
        if self.key in dict_:
            return dict_[self.key]

        # Retrieve the session bound to the state in order to perform
        # a lazy query for the attribute.
        session = _state_session(state)
        if session is None:
            # State is not bound to a session; we cannot proceed.
            return None

        # Find class for discriminator.
        # TODO: Perhaps optimize with some sort of lookup?
        discriminator = self.get_state_discriminator(state)
        target_class = _get_class_registry(state.class_).get(discriminator)

        if target_class is None:
            # Unknown discriminator; return nothing.
            return None

        id = self.get_state_id(state)

        try:
            target = session.get(target_class, id)
        except AttributeError:
            # sqlalchemy 1.3
            target = session.query(target_class).get(id)

        # Return found (or not found) target.
        return target

    def get_state_discriminator(self, state):
        """Return the discriminator value stored on the given state."""
        discriminator = self.parent_token.discriminator
        if isinstance(discriminator, hybrid_property):
            # Hybrid discriminators are read via their Python descriptor.
            return getattr(state.obj(), discriminator.__name__)
        else:
            return state.attrs[discriminator.key].value

    def get_state_id(self, state):
        """Return the identity tuple stored in the generic id columns."""
        # Lookup row with the discriminator and id.
        return tuple(state.attrs[id.key].value for id in self.parent_token.id)

    def set(self, state, dict_, initiator,
            passive=attributes.PASSIVE_OFF,
            check_old=None,
            pop=False):
        """Assign the generic attribute, syncing discriminator/id columns.

        Assigning ``None`` nulls out all backing columns; assigning an
        object writes its class name and primary key values.
        """

        # Set us on the state.
        dict_[self.key] = initiator

        if initiator is None:
            # Nullify relationship args
            for id in self.parent_token.id:
                dict_[id.key] = None
            dict_[self.parent_token.discriminator.key] = None
        else:
            # Get the primary key of the initiator and ensure we
            # can support this assignment.
            class_ = type(initiator)
            mapper = class_mapper(class_)

            pk = mapper.identity_key_from_instance(initiator)[1]

            # Set the identifier and the discriminator.
            discriminator = class_.__name__

            for index, id in enumerate(self.parent_token.id):
                dict_[id.key] = pk[index]
            dict_[self.parent_token.discriminator.key] = discriminator
|
||||
|
||||
|
||||
class GenericRelationshipProperty(MapperProperty):
    """A generic form of the relationship property.

    Creates a 1 to many relationship between the parent model
    and any other models using a discriminator (the table name).

    :param discriminator:
        Field to discriminate which model we are referring to.
    :param id:
        Field to point to the model we are referring to.
    """

    def __init__(self, discriminator, id, doc=None):
        super().__init__()
        # Raw (possibly string) column references; resolved in init().
        self._discriminator_col = discriminator
        self._id_cols = id
        self._id = None
        self._discriminator = None
        self.doc = doc

        set_creation_order(self)

    def _column_to_property(self, column):
        """Map a column or hybrid_property to the parent mapper's descriptor.

        Returns ``None`` when no matching property is found.
        """
        if isinstance(column, hybrid_property):
            attr_key = column.__name__
            for key, attr in self.parent.all_orm_descriptors.items():
                if key == attr_key:
                    return attr
        else:
            for attr in self.parent.attrs.values():
                if isinstance(attr, ColumnProperty):
                    if attr.columns[0].name == column.name:
                        return attr

    def init(self):
        """Resolve column references once the parent mapper is configured."""
        def convert_strings(column):
            # Columns may be given by name.
            if isinstance(column, str):
                return self.parent.columns[column]
            return column

        self._discriminator_col = convert_strings(self._discriminator_col)
        self._id_cols = convert_strings(self._id_cols)

        # Normalize the id columns to a list (composite keys supported).
        if isinstance(self._id_cols, Iterable):
            self._id_cols = list(map(convert_strings, self._id_cols))
        else:
            self._id_cols = [self._id_cols]

        self.discriminator = self._column_to_property(self._discriminator_col)

        if self.discriminator is None:
            raise ImproperlyConfigured(
                'Could not find discriminator descriptor.'
            )

        self.id = list(map(self._column_to_property, self._id_cols))

    class Comparator(PropComparator):
        """Comparator enabling ``==``, ``!=`` and ``is_type`` in queries."""

        def __init__(self, prop, parentmapper):
            self.property = prop
            self._parententity = parentmapper

        def __eq__(self, other):
            # Match the discriminator and every identity column of `other`.
            discriminator = type(other).__name__
            q = self.property._discriminator_col == discriminator
            other_id = identity(other)
            for index, id in enumerate(self.property._id_cols):
                q &= id == other_id[index]
            return q

        def __ne__(self, other):
            return ~(self == other)

        def is_type(self, other):
            """Return a criterion matching `other` or any of its subclasses."""
            mapper = sa.inspect(other)
            # Iterate through the weak sequence in order to get the actual
            # mappers
            class_names = [other.__name__]
            class_names.extend([
                submapper.class_.__name__
                for submapper in mapper._inheriting_mappers
            ])

            return self.property._discriminator_col.in_(class_names)

    def instrument_class(self, mapper):
        """Register the generic attribute on the mapped class."""
        attributes.register_attribute(
            mapper.class_,
            self.key,
            comparator=self.Comparator(self, mapper),
            parententity=mapper,
            doc=self.doc,
            impl_class=GenericAttributeImpl,
            parent_token=self
        )
|
||||
|
||||
|
||||
def generic_relationship(*args, **kwargs):
    """Shorthand factory for :class:`GenericRelationshipProperty`."""
    prop = GenericRelationshipProperty(*args, **kwargs)
    return prop
|
119
elitebot/lib/python3.11/site-packages/sqlalchemy_utils/i18n.py
Normal file
119
elitebot/lib/python3.11/site-packages/sqlalchemy_utils/i18n.py
Normal file
|
@ -0,0 +1,119 @@
|
|||
import sqlalchemy as sa
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.ext.hybrid import hybrid_property
|
||||
from sqlalchemy.sql.expression import ColumnElement
|
||||
|
||||
from .exceptions import ImproperlyConfigured
|
||||
|
||||
try:
|
||||
import babel
|
||||
import babel.dates
|
||||
except ImportError:
|
||||
babel = None
|
||||
|
||||
|
||||
def get_locale():
    """Return the default locale (``en``) as a Babel ``Locale``.

    Applications are expected to override this module-level function with
    one returning the active locale. Raises
    :exc:`ImproperlyConfigured` when Babel is not installed.
    """
    try:
        locale = babel.Locale('en')
    except AttributeError:
        # babel is optional; when missing, attribute access on the
        # placeholder None raises AttributeError.
        raise ImproperlyConfigured(
            'Could not load get_locale function using Babel. Either '
            'install Babel or make a similar function and override it '
            'in this module.'
        )
    return locale
|
||||
|
||||
|
||||
def cast_locale(obj, locale, attr):
    """
    Cast given locale to string. Supports also callbacks that return locales.

    :param obj:
        Object or class to use as a possible parameter to locale callable
    :param locale:
        Locale object or string or callable that returns a locale.
    :param attr:
        Attribute whose key may be passed to a locale callable.
    """
    resolved = locale
    if callable(resolved):
        # Try the richest call signature first, falling back on TypeError:
        # (obj, attr_key) -> (obj) -> ().
        try:
            resolved = resolved(obj, attr.key)
        except TypeError:
            try:
                resolved = resolved(obj)
            except TypeError:
                resolved = resolved()
    return str(resolved) if isinstance(resolved, babel.Locale) else resolved
|
||||
|
||||
|
||||
class cast_locale_expr(ColumnElement):
    """A deferred, SQL-expression form of :func:`cast_locale`.

    Wraps the target class, the locale (possibly a callable) and the
    attribute so that locale resolution happens at statement compile time
    via the ``@compiles`` handler registered in this module.
    """

    # Output depends on runtime locale state, so compiled caching is unsafe.
    inherit_cache = False

    def __init__(self, cls, locale, attr):
        self.cls = cls
        self.locale = locale
        self.attr = attr
|
||||
|
||||
|
||||
@compiles(cast_locale_expr)
def compile_cast_locale_expr(element, compiler, **kw):
    """Compile a :class:`cast_locale_expr` node.

    Resolves the (possibly callable) locale at compile time; a plain string
    result is emitted as a quoted SQL string literal, anything else is
    compiled as a sub-expression.
    """
    locale = cast_locale(element.cls, element.locale, element.attr)
    if isinstance(locale, str):
        return f"'{locale}'"
    return compiler.process(locale)
|
||||
|
||||
|
||||
class TranslationHybrid:
    """Factory building hybrid properties that expose one translation out of
    a mapping column keyed by locale.

    :param current_locale: Locale (or callable returning one) used for reads
        and writes.
    :param default_locale: Fallback locale (or callable) used when no value
        exists for the current locale.
    :param default_value: Value returned when neither locale has a value.
    """

    def __init__(self, current_locale, default_locale, default_value=None):
        if babel is None:
            raise ImproperlyConfigured(
                'You need to install babel in order to use TranslationHybrid.'
            )
        self.current_locale = current_locale
        self.default_locale = default_locale
        self.default_value = default_value

    def getter_factory(self, attr):
        """
        Return a hybrid_property getter function for given attribute. The
        returned getter first checks if object has translation for current
        locale. If not it tries to get translation for default locale. If
        there is no translation found for default locale it returns
        ``default_value``.
        """
        def getter(obj):
            current_locale = cast_locale(obj, self.current_locale, attr)
            try:
                return getattr(obj, attr.key)[current_locale]
            except (TypeError, KeyError):
                # TypeError: underlying mapping is None; KeyError: no entry.
                default_locale = cast_locale(obj, self.default_locale, attr)
                try:
                    return getattr(obj, attr.key)[default_locale]
                except (TypeError, KeyError):
                    return self.default_value
        return getter

    def setter_factory(self, attr):
        """Return a setter writing the value under the current locale key."""
        def setter(obj, value):
            # Initialize the mapping on first write.
            if getattr(obj, attr.key) is None:
                setattr(obj, attr.key, {})
            locale = cast_locale(obj, self.current_locale, attr)
            getattr(obj, attr.key)[locale] = value
        return setter

    def expr_factory(self, attr):
        """Return the SQL expression: current-locale value with default-locale
        fallback via COALESCE."""
        def expr(cls):
            cls_attr = getattr(cls, attr.key)
            current_locale = cast_locale_expr(cls, self.current_locale, attr)
            default_locale = cast_locale_expr(cls, self.default_locale, attr)
            return sa.func.coalesce(
                cls_attr[current_locale],
                cls_attr[default_locale]
            )
        return expr

    def __call__(self, attr):
        """Build the hybrid_property for the given mapping attribute."""
        return hybrid_property(
            fget=self.getter_factory(attr),
            fset=self.setter_factory(attr),
            expr=self.expr_factory(attr)
        )
|
|
@ -0,0 +1,277 @@
|
|||
import sqlalchemy as sa
|
||||
|
||||
from .exceptions import ImproperlyConfigured
|
||||
|
||||
|
||||
def coercion_listener(mapper, class_):
    """
    Attach a coercing ``set`` listener to every property of ``class_`` whose
    column type exposes a ``coercion_listener`` attribute.
    """
    for prop in mapper.iterate_properties:
        try:
            coerce = prop.columns[0].type.coercion_listener
        except AttributeError:
            # Not a column property, or its type is not coercion capable.
            continue
        sa.event.listen(getattr(class_, prop.key), 'set', coerce, retval=True)
|
||||
|
||||
|
||||
def instant_defaults_listener(target, args, kwargs):
    """Populate column defaults into ``kwargs`` at ``__init__`` time.

    Defaults are inserted first and the caller-supplied kwargs re-applied
    afterwards, so explicit values keep a later insertion-order position
    and override the defaults.
    """
    supplied = kwargs.copy()
    kwargs.clear()

    for name, column in sa.inspect(target.__class__).columns.items():
        default = getattr(column, 'default', None)
        if default is None:
            continue
        arg = default.arg
        # Callable defaults receive the target instance.
        kwargs[name] = arg(target) if callable(arg) else arg

    # Supersede with the initial kwargs in case target uses setters that
    # override defaults.
    kwargs.update(supplied)
|
||||
|
||||
|
||||
def force_auto_coercion(mapper=None):
    """
    Enable automatic data type coercion for all coercion capable column
    properties of every class mapped by the given mapper (all SQLAlchemy
    mappers by default).

    Call this before defining your models::

        from sqlalchemy_utils import force_auto_coercion

        force_auto_coercion()

    Afterwards, assigning a scalar to a coercion capable column (e.g. a
    ``ColorType``) converts it to the appropriate value object at assignment
    time. A useful side effect is early validation: invalid values raise on
    assignment instead of being silently stored as strings.

    :param mapper: The mapper which the automatic data type coercion should be
                   applied to
    """
    target = sa.orm.Mapper if mapper is None else mapper
    sa.event.listen(target, 'mapper_configured', coercion_listener)
|
||||
|
||||
|
||||
def force_instant_defaults(mapper=None):
    """
    Assign column defaults on object initialization time instead of at flush
    time (applies to all SQLAlchemy mappers by default).

    ::

        from sqlalchemy_utils import force_instant_defaults

        force_instant_defaults()

        class Document(Base):
            __tablename__ = 'document'
            id = sa.Column(sa.Integer, autoincrement=True)
            created_at = sa.Column(sa.DateTime, default=datetime.now)

        document = Document()
        document.created_at  # datetime object

    :param mapper: The mapper which the automatic instant defaults forcing
                   should be applied to
    """
    target = sa.orm.Mapper if mapper is None else mapper
    sa.event.listen(target, 'init', instant_defaults_listener)
|
||||
|
||||
|
||||
def auto_delete_orphans(attr):
    """
    Delete many-to-many orphans automatically for the given relationship
    attribute.

    After calling this, whenever a flush removes the last association between
    a parent and an object on the other side of ``attr`` (or deletes a parent
    outright), every now-orphaned target row is removed with a single bulk
    DELETE. The relationship must define a backref, which is used to detect
    orphans. Based on the recipe at
    https://bitbucket.org/zzzeek/sqlalchemy/wiki/UsageRecipes/ManyToManyOrphan.

    Example::

        from sqlalchemy_utils import auto_delete_orphans

        auto_delete_orphans(Entry.tags)

    .. versionadded: 0.26.4

    :param attr: Association relationship attribute to auto delete orphans from
    """
    owner_class = attr.parent.class_
    orphan_class = attr.property.mapper.class_

    backref = attr.property.backref
    if not backref:
        raise ImproperlyConfigured(
            'The relationship argument given for auto_delete_orphans needs to '
            'have a backref relationship set.'
        )
    if isinstance(backref, tuple):
        # A backref may be declared as (name, kwargs); only the name matters.
        backref = backref[0]

    @sa.event.listens_for(sa.orm.Session, 'after_flush')
    def delete_orphan_listener(session, ctx):
        # Did this flush detach targets from an owner, or delete an owner?
        relation_changed = any(
            isinstance(obj, owner_class) and
            sa.orm.attributes.get_history(obj, attr.key).deleted
            for obj in session.dirty
        )
        owner_deleted = any(
            isinstance(obj, owner_class)
            for obj in session.deleted
        )

        if relation_changed or owner_deleted:
            # Emit a single DELETE for every target with no remaining owner.
            session.query(orphan_class).filter(
                ~getattr(orphan_class, backref).any()
            ).delete(synchronize_session=False)
|
|
@ -0,0 +1,96 @@
|
|||
from datetime import datetime
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
class Timestamp:
    """Adds `created` and `updated` columns to a derived declarative model.

    The `created` column is handled through a default and the `updated`
    column is handled through a `before_update` event that propagates
    for all derived declarative models.

    ::


        import sqlalchemy as sa
        from sqlalchemy_utils import Timestamp


        class SomeModel(Base, Timestamp):
            __tablename__ = 'somemodel'
            id = sa.Column(sa.Integer, primary_key=True)
    """

    # Both columns default to naive UTC timestamps (datetime.utcnow) on
    # insert; `updated` is refreshed by the before_update listener below.
    created = sa.Column(sa.DateTime, default=datetime.utcnow, nullable=False)
    updated = sa.Column(sa.DateTime, default=datetime.utcnow, nullable=False)
|
||||
|
||||
|
||||
@sa.event.listens_for(Timestamp, 'before_update', propagate=True)
def timestamp_before_update(mapper, connection, target):
    """Refresh ``updated`` on every UPDATE of a Timestamp-derived model."""
    # When a model with a timestamp is updated; force update the updated
    # timestamp.
    target.updated = datetime.utcnow()
|
||||
|
||||
|
||||
NOT_LOADED_REPR = '<not loaded>'
|
||||
|
||||
|
||||
def _generic_repr_method(self, fields):
    """Render ``ClassName(field=value, ...)`` without triggering lazy loads.

    :param fields: Field names to include, in order; falsy means all mapped
        columns. Unloaded attributes render as ``<not loaded>``.
    """
    state = sa.inspect(self)
    keys = fields if fields else state.mapper.columns.keys()

    parts = []
    for key in keys:
        if key in state.unloaded:
            # Do not force a lazy load just for repr purposes.
            rendered = NOT_LOADED_REPR
        else:
            rendered = repr(state.attrs[key].loaded_value)
        parts.append(f'{key}={rendered}')

    return f"{self.__class__.__name__}({', '.join(parts)})"
|
||||
|
||||
|
||||
def generic_repr(*fields):
    """Adds generic ``__repr__()`` method to a declarative SQLAlchemy model.

    In case if some fields are not loaded from a database, it doesn't
    force their loading and instead represents them as ``<not loaded>``.

    In addition, user can provide field names as arguments to the decorator
    to specify what fields should present in the string representation
    and in what order.

    Example::


        import sqlalchemy as sa
        from sqlalchemy_utils import generic_repr


        @generic_repr
        class MyModel(Base):
            __tablename__ = 'mymodel'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.String)
            category = sa.Column(sa.String)

        session.add(MyModel(name='Foo', category='Bar'))
        session.commit()
        foo = session.query(MyModel).options(sa.orm.defer('category')).one()

        assert repr(foo) == "MyModel(id=1, name='Foo', category=<not loaded>)"
    """
    # Used as a bare decorator (@generic_repr): the single positional
    # argument is the class itself, and all columns are included.
    if len(fields) == 1 and callable(fields[0]):
        target = fields[0]
        target.__repr__ = lambda self: _generic_repr_method(self, fields=None)
        return target
    else:
        # Used with arguments (@generic_repr('id', 'name')): build a
        # decorator that limits repr to the named fields, in order.
        def decorator(cls):
            cls.__repr__ = lambda self: _generic_repr_method(
                self,
                fields=fields
            )
            return cls
        return decorator
|
|
@ -0,0 +1,376 @@
|
|||
"""
|
||||
This module provides a decorator function for observing changes in a given
|
||||
property. Internally the decorator is implemented using SQLAlchemy event
|
||||
listeners. Both column properties and relationship properties can be observed.
|
||||
|
||||
Property observers can be used for pre-calculating aggregates and automatic
|
||||
real-time data denormalization.
|
||||
|
||||
Simple observers
|
||||
----------------
|
||||
|
||||
At the heart of the observer extension is the :func:`observes` decorator. You
|
||||
mark some property path as being observed and the marked method will get
|
||||
notified when any changes are made to given path.
|
||||
|
||||
Consider the following model structure:
|
||||
|
||||
::
|
||||
|
||||
class Director(Base):
|
||||
__tablename__ = 'director'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.String)
|
||||
date_of_birth = sa.Column(sa.Date)
|
||||
|
||||
class Movie(Base):
|
||||
__tablename__ = 'movie'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
name = sa.Column(sa.String)
|
||||
director_id = sa.Column(sa.Integer, sa.ForeignKey(Director.id))
|
||||
director = sa.orm.relationship(Director, backref='movies')
|
||||
|
||||
|
||||
Now consider we want to show movies in some listing ordered by director id
|
||||
first and movie id secondly. If we have many movies then using joins and
|
||||
ordering by Director.name will be very slow. Here is where denormalization
|
||||
and :func:`observes` comes to rescue the day. Let's add a new column called
|
||||
director_name to Movie which will get automatically copied from associated
|
||||
Director.
|
||||
|
||||
|
||||
::
|
||||
|
||||
from sqlalchemy_utils import observes
|
||||
|
||||
|
||||
class Movie(Base):
|
||||
# same as before..
|
||||
director_name = sa.Column(sa.String)
|
||||
|
||||
@observes('director')
|
||||
def director_observer(self, director):
|
||||
self.director_name = director.name
|
||||
|
||||
.. note::
|
||||
|
||||
This example could be done much more efficiently using a compound foreign
|
||||
key from director_name, director_id to Director.name, Director.id but for
|
||||
the sake of simplicity we added this as an example.
|
||||
|
||||
|
||||
Observes vs aggregated
|
||||
----------------------
|
||||
|
||||
:func:`observes` and :func:`.aggregates.aggregated` can be used for similar
|
||||
things. However performance wise you should take the following things into
|
||||
consideration:
|
||||
|
||||
* :func:`observes` works always inside transaction and deals with objects. If
|
||||
the relationship observer is observing has a large number of objects it's
|
||||
better to use :func:`.aggregates.aggregated`.
|
||||
* :func:`.aggregates.aggregated` always executes one additional query per
|
||||
aggregate so in scenarios where the observed relationship has only a handful
|
||||
of objects it's better to use :func:`observes` instead.
|
||||
|
||||
|
||||
Example 1. Movie with many ratings
|
||||
|
||||
Let's say we have a Movie object with potentially thousands of ratings. In this
|
||||
case we should always use :func:`.aggregates.aggregated` since iterating
|
||||
through thousands of objects is slow and very memory consuming.
|
||||
|
||||
Example 2. Product with denormalized catalog name
|
||||
|
||||
Each product belongs to one catalog. Here it is natural to use :func:`observes`
|
||||
for data denormalization.
|
||||
|
||||
|
||||
Deeply nested observing
|
||||
-----------------------
|
||||
|
||||
Consider the following model structure where Catalog has many Categories and
|
||||
Category has many Products.
|
||||
|
||||
::
|
||||
|
||||
class Catalog(Base):
|
||||
__tablename__ = 'catalog'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
product_count = sa.Column(sa.Integer, default=0)
|
||||
|
||||
@observes('categories.products')
|
||||
def product_observer(self, products):
|
||||
self.product_count = len(products)
|
||||
|
||||
categories = sa.orm.relationship('Category', backref='catalog')
|
||||
|
||||
class Category(Base):
|
||||
__tablename__ = 'category'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))
|
||||
|
||||
products = sa.orm.relationship('Product', backref='category')
|
||||
|
||||
class Product(Base):
|
||||
__tablename__ = 'product'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
price = sa.Column(sa.Numeric)
|
||||
|
||||
category_id = sa.Column(sa.Integer, sa.ForeignKey('category.id'))
|
||||
|
||||
|
||||
:func:`observes` is smart enough to:
|
||||
|
||||
* Notify catalog objects of any changes in associated Product objects
|
||||
* Notify catalog objects of any changes in Category objects that affect
|
||||
products (for example if Category gets deleted, or a new Category is added to
|
||||
Catalog with any number of Products)
|
||||
|
||||
|
||||
::
|
||||
|
||||
category = Category(
|
||||
products=[Product(), Product()]
|
||||
)
|
||||
category2 = Category(
|
||||
            products=[Product()]
|
||||
)
|
||||
|
||||
catalog = Catalog(
|
||||
categories=[category, category2]
|
||||
)
|
||||
session.add(catalog)
|
||||
session.commit()
|
||||
catalog.product_count # 2
|
||||
|
||||
session.delete(category)
|
||||
session.commit()
|
||||
catalog.product_count # 1
|
||||
|
||||
|
||||
Observing multiple columns
|
||||
--------------------------
|
||||
|
||||
You can also observe multiple columns by specifying all the observable columns
|
||||
in the decorator.
|
||||
|
||||
|
||||
::
|
||||
|
||||
class Order(Base):
|
||||
__tablename__ = 'order'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
unit_price = sa.Column(sa.Integer)
|
||||
amount = sa.Column(sa.Integer)
|
||||
total_price = sa.Column(sa.Integer)
|
||||
|
||||
@observes('amount', 'unit_price')
|
||||
def total_price_observer(self, amount, unit_price):
|
||||
self.total_price = amount * unit_price
|
||||
|
||||
|
||||
|
||||
"""
|
||||
import itertools
|
||||
from collections import defaultdict, namedtuple
|
||||
from collections.abc import Iterable
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from .functions import getdotattr, has_changes
|
||||
from .path import AttrPath
|
||||
from .utils import is_sequence
|
||||
|
||||
Callback = namedtuple('Callback', ['func', 'backref', 'fullpath'])
|
||||
|
||||
|
||||
class PropertyObserver:
    """Registry and dispatcher for :func:`observes` callbacks.

    Collects observer-marked methods as mappers are configured, expands
    their observed paths into per-class callbacks (including backref paths
    for relationship hops), and invokes them at session ``before_flush``
    with the current values of the observed paths.
    """

    def __init__(self):
        # (target, event name, handler) triples managed as one unit by
        # register_listeners/remove_listeners.
        self.listener_args = [
            (
                sa.orm.Mapper,
                'mapper_configured',
                self.update_generator_registry
            ),
            (
                sa.orm.Mapper,
                'after_configured',
                self.gather_paths
            ),
            (
                sa.orm.session.Session,
                'before_flush',
                self.invoke_callbacks
            )
        ]
        # class -> list of Callback entries that must fire for that class.
        self.callback_map = defaultdict(list)
        # TODO: make the registry a WeakKey dict
        # class -> list of observer-marked functions found on that class.
        self.generator_registry = defaultdict(list)

    def remove_listeners(self):
        """Detach all SQLAlchemy event listeners owned by this observer."""
        for args in self.listener_args:
            sa.event.remove(*args)

    def register_listeners(self):
        """Attach the event listeners (idempotent)."""
        for args in self.listener_args:
            if not sa.event.contains(*args):
                sa.event.listen(*args)

    def __repr__(self):
        return '<PropertyObserver>'

    def update_generator_registry(self, mapper, class_):
        """
        Adds generator functions to generator_registry.
        """

        for generator in class_.__dict__.values():
            # Functions decorated with @observes carry __observes__.
            if hasattr(generator, '__observes__'):
                self.generator_registry[class_].append(
                    generator
                )

    def gather_paths(self):
        """Expand observed paths into callback_map entries.

        For each relationship hop in a path, also register a callback on the
        related class keyed by the inverted path, so changes deep in the path
        still notify the root object.
        """
        for class_, generators in self.generator_registry.items():
            for callback in generators:
                full_paths = []
                for call_path in callback.__observes__:
                    full_paths.append(AttrPath(class_, call_path))

                for path in full_paths:
                    self.callback_map[class_].append(
                        Callback(
                            func=callback,
                            backref=None,
                            fullpath=full_paths
                        )
                    )

                    for index in range(len(path)):
                        i = index + 1
                        prop = path[index].property
                        if isinstance(prop, sa.orm.RelationshipProperty):
                            prop_class = path[index].property.mapper.class_
                            self.callback_map[prop_class].append(
                                Callback(
                                    func=callback,
                                    # Inverted path leads from the related
                                    # class back to the observing root.
                                    backref=~ (path[:i]),
                                    fullpath=full_paths
                                )
                            )

    def gather_callback_args(self, obj, callbacks):
        """Yield (root_obj, func, objects) tuples for callbacks triggered
        by changes on ``obj``."""
        session = sa.orm.object_session(obj)
        for callback in callbacks:
            backref = callback.backref

            # Walk back from the changed object to the observing root(s).
            root_objs = getdotattr(obj, backref) if backref else obj
            if root_objs:
                if not isinstance(root_objs, Iterable):
                    root_objs = [root_objs]

                # Avoid premature flushes while reading relationship values.
                with session.no_autoflush:
                    for root_obj in root_objs:
                        if root_obj:
                            args = self.get_callback_args(root_obj, callback)
                            if args:
                                yield args

    def get_callback_args(self, root_obj, callback):
        """Build the invocation tuple for one root object, or None when
        nothing relevant changed."""
        session = sa.orm.object_session(root_obj)
        # Resolve current values of every observed path, filtering out
        # objects pending deletion.
        objects = [getdotattr(
            root_obj,
            path,
            lambda obj: obj not in session.deleted
        ) for path in callback.fullpath]
        paths = [str(path) for path in callback.fullpath]
        for path in paths:
            # Multi-hop paths always fire; single columns only on change.
            if '.' in path or has_changes(root_obj, path):
                return (
                    root_obj,
                    callback.func,
                    objects
                )

    def iterate_objects_and_callbacks(self, session):
        """Yield (obj, callbacks) pairs for every pending object that has
        registered callbacks."""
        objs = itertools.chain(session.new, session.dirty, session.deleted)
        for obj in objs:
            for class_, callbacks in self.callback_map.items():
                if isinstance(obj, class_):
                    yield obj, callbacks

    def invoke_callbacks(self, session, ctx, instances):
        """before_flush handler: aggregate triggered callbacks, de-duplicate
        their arguments, then invoke each once per root object."""
        callback_args = defaultdict(lambda: defaultdict(set))
        for obj, callbacks in self.iterate_objects_and_callbacks(session):
            args = self.gather_callback_args(obj, callbacks)
            for (root_obj, func, objects) in args:
                if not callback_args[root_obj][func]:
                    callback_args[root_obj][func] = {}
                for i, object_ in enumerate(objects):
                    if is_sequence(object_):
                        # Merge collections gathered from multiple triggers.
                        callback_args[root_obj][func][i] = (
                            callback_args[root_obj][func].get(i, set()) |
                            set(object_)
                        )
                    else:
                        callback_args[root_obj][func][i] = object_

        for root_obj, callback_objs in callback_args.items():
            for callback, objs in callback_objs.items():
                # Positional args follow the declaration order of the paths.
                callback(root_obj, *[objs[i] for i in range(len(objs))])
|
||||
|
||||
|
||||
observer = PropertyObserver()
|
||||
|
||||
|
||||
def observes(*paths, **observer_kw):
    """
    Mark a method as a property observer for the given property path(s).

    Inside a transaction the observer gathers all changes made along the
    observed paths and feeds the changed objects to the decorated method
    at the before-flush phase.

    ::

        from sqlalchemy_utils import observes


        class Catalog(Base):
            __tablename__ = 'catalog'
            id = sa.Column(sa.Integer, primary_key=True)
            category_count = sa.Column(sa.Integer, default=0)

            @observes('categories')
            def category_observer(self, categories):
                self.category_count = len(categories)

        class Category(Base):
            __tablename__ = 'category'
            id = sa.Column(sa.Integer, primary_key=True)
            catalog_id = sa.Column(sa.Integer, sa.ForeignKey('catalog.id'))


        catalog = Catalog(categories=[Category(), Category()])
        session.add(catalog)
        session.commit()

        catalog.category_count  # 2


    .. versionadded: 0.28.0

    :param *paths: One or more dot-notated property paths, eg.
        'categories.products.price'
    :param **observer_kw: May contain key 'observer' holding the
        :meth:`PropertyObserver` instance to register with; defaults to
        the module-level observer.
    """
    try:
        observer_ = observer_kw.pop('observer')
    except KeyError:
        # Fall back to the module-level PropertyObserver singleton.
        observer_ = observer
    observer_.register_listeners()

    def decorate(func):
        def observing_method(self, *args, **kwargs):
            return func(self, *args, **kwargs)
        # Marker consumed by PropertyObserver when scanning classes.
        observing_method.__observes__ = paths
        return observing_method
    return decorate
|
|
@ -0,0 +1,74 @@
|
|||
import sqlalchemy as sa
|
||||
|
||||
|
||||
def inspect_type(mixed):
    """
    Return the SQLAlchemy type object behind an instrumented attribute,
    a ColumnProperty or a plain Column; None for anything else.
    """
    if isinstance(mixed, sa.orm.attributes.InstrumentedAttribute):
        return mixed.property.columns[0].type
    if isinstance(mixed, sa.orm.ColumnProperty):
        return mixed.columns[0].type
    if isinstance(mixed, sa.Column):
        return mixed.type
|
||||
|
||||
|
||||
def is_case_insensitive(mixed):
    """
    Return True when the given column/attribute uses a
    CaseInsensitiveComparator (either as instance comparator or as a
    comparator factory), False otherwise.
    """
    # inspect_type stays inside the try blocks: for e.g. relationship
    # attributes it raises AttributeError, which must mean "no".
    try:
        comparator = inspect_type(mixed).comparator
    except AttributeError:
        pass
    else:
        return isinstance(comparator, CaseInsensitiveComparator)
    try:
        factory = inspect_type(mixed).comparator_factory
    except AttributeError:
        return False
    return issubclass(factory, CaseInsensitiveComparator)
|
||||
|
||||
|
||||
class CaseInsensitiveComparator(sa.Unicode.Comparator):
    """
    Unicode comparator that lowercases the right-hand operand of string
    comparisons (unless it is None or already case insensitive), making
    the comparison case insensitive at SQL level.
    """

    @classmethod
    def lowercase_arg(cls, func):
        """Return a wrapper around the named base-comparator method that
        lowercases its right operand when needed."""
        def operation(self, other, **kwargs):
            base_op = getattr(sa.Unicode.Comparator, func)
            if other is not None and not is_case_insensitive(other):
                other = sa.func.lower(other)
            return base_op(self, other, **kwargs)
        return operation

    def in_(self, other):
        if isinstance(other, (list, tuple)):
            other = map(sa.func.lower, other)
        return sa.Unicode.Comparator.in_(self, other)

    def notin_(self, other):
        if isinstance(other, (list, tuple)):
            other = map(sa.func.lower, other)
        return sa.Unicode.Comparator.notin_(self, other)
|
||||
|
||||
|
||||
# Comparator methods whose right operand should be lowercased so that the
# resulting SQL comparison is case insensitive.
string_operator_funcs = [
    '__eq__', '__ne__', '__lt__', '__le__', '__gt__', '__ge__',
    'concat', 'contains', 'ilike', 'like', 'notlike', 'notilike',
    'startswith', 'endswith',
]

# Install a lowercasing wrapper for each operator on the comparator class.
for func in string_operator_funcs:
    setattr(
        CaseInsensitiveComparator,
        func,
        CaseInsensitiveComparator.lowercase_arg(func)
    )
|
152
elitebot/lib/python3.11/site-packages/sqlalchemy_utils/path.py
Normal file
152
elitebot/lib/python3.11/site-packages/sqlalchemy_utils/path.py
Normal file
|
@ -0,0 +1,152 @@
|
|||
import sqlalchemy as sa
|
||||
from sqlalchemy.orm.attributes import InstrumentedAttribute
|
||||
from sqlalchemy.util.langhelpers import symbol
|
||||
|
||||
from .utils import str_coercible
|
||||
|
||||
|
||||
@str_coercible
class Path:
    """
    Represents a separator-delimited attribute path string (for example
    ``'categories.products'``) and offers list-like access to its parts.
    """
    def __init__(self, path, separator='.'):
        # Copying from another Path keeps only its string; the separator
        # argument still applies (the source Path's separator is not
        # copied -- matches existing behavior).
        if isinstance(path, Path):
            self.path = path.path
        else:
            self.path = path
        self.separator = separator

    @property
    def parts(self):
        """Return the path split into its component strings."""
        return self.path.split(self.separator)

    def __iter__(self):
        yield from self.parts

    def __len__(self):
        return len(self.parts)

    def __repr__(self):
        return f"{self.__class__.__name__}('{self.path}')"

    def index(self, element):
        """Return the position of *element* among the parts.

        :raises ValueError: if *element* is not a part of this path.
        """
        return self.parts.index(element)

    def __getitem__(self, key):
        # Parameter renamed from ``slice`` to avoid shadowing the
        # builtin; subscription syntax is unaffected. Integer indexing
        # returns a plain string part; slicing returns a new Path joined
        # with the same separator.
        result = self.parts[key]
        if isinstance(result, list):
            return self.__class__(
                self.separator.join(result),
                separator=self.separator
            )
        return result

    def __eq__(self, other):
        # Assumes *other* is Path-like (has .path and .separator);
        # comparing against e.g. a plain string raises AttributeError.
        return self.path == other.path and self.separator == other.separator

    def __ne__(self, other):
        return not (self == other)

    def __unicode__(self):
        # Consumed by the str_coercible decorator to provide __str__.
        return self.path
|
||||
|
||||
|
||||
def get_attr(mixed, attr):
    """
    Resolve *attr* either on the related class of an instrumented
    relationship attribute, or directly on the given object.
    """
    if isinstance(mixed, InstrumentedAttribute):
        # Relationship attribute: look the name up on the target class.
        target = mixed.property.mapper.class_
        return getattr(target, attr)
    return getattr(mixed, attr)
|
||||
|
||||
|
||||
@str_coercible
class AttrPath:
    """
    Materialized attribute path: resolves a dot-separated path string
    into the corresponding instrumented attributes starting from
    ``class_``.
    """
    def __init__(self, class_, path):
        self.class_ = class_
        self.path = Path(path)
        self.parts = []
        last_attr = class_
        for value in self.path:
            last_attr = get_attr(last_attr, value)
            self.parts.append(last_attr)

    def __iter__(self):
        yield from self.parts

    def __invert__(self):
        """Return the reversed path built from relationship backrefs."""
        def get_backref(part):
            prop = part.property
            backref = prop.backref or prop.back_populates
            if backref is None:
                raise Exception(
                    "Invert failed because property '%s' of class "
                    "%s has no backref." % (
                        prop.key,
                        prop.parent.class_.__name__
                    )
                )
            # backref may be declared as a (name, kwargs) tuple.
            if isinstance(backref, tuple):
                return backref[0]
            else:
                return backref

        if isinstance(self.parts[-1].property, sa.orm.ColumnProperty):
            class_ = self.parts[-1].class_
        else:
            class_ = self.parts[-1].mapper.class_

        return self.__class__(
            class_,
            '.'.join(map(get_backref, reversed(self.parts)))
        )

    def index(self, element):
        # Identity-based (is) lookup; implicitly returns None if absent.
        for index, el in enumerate(self.parts):
            if el is element:
                return index

    @property
    def direction(self):
        """Combined relationship direction symbol for the whole path."""
        symbols = [part.property.direction for part in self.parts]
        if symbol('MANYTOMANY') in symbols:
            return symbol('MANYTOMANY')
        elif symbol('MANYTOONE') in symbols and symbol('ONETOMANY') in symbols:
            # A many-to-one hop combined with a one-to-many hop behaves
            # like many-to-many for the path as a whole.
            return symbol('MANYTOMANY')
        return symbols[0]

    @property
    def uselist(self):
        # True if any hop in the path is collection-valued.
        return any(part.property.uselist for part in self.parts)

    def __getitem__(self, slice):
        result = self.parts[slice]
        if isinstance(result, list) and result:
            # Slices rebuild an AttrPath anchored at the right class.
            if result[0] is self.parts[0]:
                class_ = self.class_
            else:
                class_ = result[0].parent.class_
            return self.__class__(
                class_,
                self.path[slice]
            )
        else:
            return result

    def __len__(self):
        return len(self.path)

    def __repr__(self):
        return "{}({}, {!r})".format(
            self.__class__.__name__,
            self.class_.__name__,
            self.path.path
        )

    def __eq__(self, other):
        # Assumes *other* is an AttrPath; no type guard here.
        return self.path == other.path and self.class_ == other.class_

    def __ne__(self, other):
        return not (self == other)

    def __unicode__(self):
        return str(self.path)
|
|
@ -0,0 +1,5 @@
|
|||
from .country import Country # noqa
|
||||
from .currency import Currency # noqa
|
||||
from .ltree import Ltree # noqa
|
||||
from .weekday import WeekDay # noqa
|
||||
from .weekdays import WeekDays # noqa
|
|
@ -0,0 +1,110 @@
|
|||
from functools import total_ordering
|
||||
|
||||
from .. import i18n
|
||||
from ..utils import str_coercible
|
||||
|
||||
|
||||
@total_ordering
@str_coercible
class Country:
    """
    Country class wraps a 2 to 3 letter country code. It provides various
    convenience properties and methods.

    ::

        from babel import Locale
        from sqlalchemy_utils import Country, i18n


        # First lets add a locale getter for testing purposes
        i18n.get_locale = lambda: Locale('en')


        Country('FI').name  # Finland
        Country('FI').code  # FI

        Country(Country('FI')).code  # 'FI'

    Country always validates the given code if you use at least the optional
    dependency list 'babel', otherwise no validation are performed.

    ::

        Country(None)  # raises TypeError

        Country('UnknownCode')  # raises ValueError


    Country supports equality operators.

    ::

        Country('FI') == Country('FI')
        Country('FI') != Country('US')


    Country objects are hashable.

    ::

        assert hash(Country('FI')) == hash('FI')

    """
    def __init__(self, code_or_country):
        if isinstance(code_or_country, Country):
            self.code = code_or_country.code
        elif isinstance(code_or_country, str):
            self.validate(code_or_country)
            self.code = code_or_country
        else:
            raise TypeError(
                "Country() argument must be a string or a country, not '{}'"
                .format(
                    type(code_or_country).__name__
                )
            )

    @property
    def name(self):
        """Localized country name from the current locale's territories."""
        return i18n.get_locale().territories[self.code]

    @classmethod
    def validate(cls, code):
        """Raise ValueError for an unknown code; no-op when babel is
        missing.

        (First parameter renamed ``self`` -> ``cls`` to match classmethod
        convention, consistent with Ltree.validate; callers unaffected.)
        """
        try:
            i18n.babel.Locale('en').territories[code]
        except KeyError:
            raise ValueError(
                f'Could not convert string to country code: {code}'
            )
        except AttributeError:
            # As babel is optional, we may raise an AttributeError accessing it
            pass

    def __eq__(self, other):
        if isinstance(other, Country):
            return self.code == other.code
        elif isinstance(other, str):
            return self.code == other
        else:
            return NotImplemented

    def __hash__(self):
        # Hash by code so Country('FI') and 'FI' collide by design.
        return hash(self.code)

    def __ne__(self, other):
        return not (self == other)

    def __lt__(self, other):
        # total_ordering derives the remaining comparisons from this.
        if isinstance(other, Country):
            return self.code < other.code
        elif isinstance(other, str):
            return self.code < other
        return NotImplemented

    def __repr__(self):
        return f'{self.__class__.__name__}({self.code!r})'

    def __unicode__(self):
        # Used by str_coercible to provide __str__; returns the localized
        # name, not the code.
        return self.name
|
|
@ -0,0 +1,109 @@
|
|||
from .. import i18n, ImproperlyConfigured
|
||||
from ..utils import str_coercible
|
||||
|
||||
|
||||
@str_coercible
class Currency:
    """
    Currency class wraps a 3-letter currency code. It provides various
    convenience properties and methods.

    ::

        from babel import Locale
        from sqlalchemy_utils import Currency, i18n


        # First lets add a locale getter for testing purposes
        i18n.get_locale = lambda: Locale('en')


        Currency('USD').name  # US Dollar
        Currency('USD').symbol  # $

        Currency(Currency('USD')).code  # 'USD'

    Currency always validates the given code if you use at least the optional
    dependency list 'babel', otherwise no validation are performed.

    ::

        Currency(None)  # raises TypeError

        Currency('UnknownCode')  # raises ValueError


    Currency supports equality operators.

    ::

        Currency('USD') == Currency('USD')
        Currency('USD') != Currency('EUR')


    Currencies are hashable.

    ::

        len(set([Currency('USD'), Currency('USD')]))  # 1

    """
    def __init__(self, code):
        if i18n.babel is None:
            raise ImproperlyConfigured(
                "'babel' package is required in order to use Currency class."
            )
        if isinstance(code, Currency):
            # Fixed: copy the wrapped string code instead of storing the
            # Currency instance itself. The class docstring promises
            # Currency(Currency('USD')).code == 'USD', and the sibling
            # Country class copies .code the same way.
            self.code = code.code
        elif isinstance(code, str):
            self.validate(code)
            self.code = code
        else:
            raise TypeError(
                'First argument given to Currency constructor should be '
                'either an instance of Currency or valid three letter '
                'currency code.'
            )

    @classmethod
    def validate(cls, code):
        """Raise ValueError for an unknown code; no-op when babel is
        missing.

        (First parameter renamed ``self`` -> ``cls`` to match classmethod
        convention; callers unaffected.)
        """
        try:
            i18n.babel.Locale('en').currencies[code]
        except KeyError:
            raise ValueError(f"'{code}' is not valid currency code.")
        except AttributeError:
            # As babel is optional, we may raise an AttributeError accessing it
            pass

    @property
    def symbol(self):
        """Localized currency symbol, e.g. '$' for USD in locale 'en'."""
        return i18n.babel.numbers.get_currency_symbol(
            self.code,
            i18n.get_locale()
        )

    @property
    def name(self):
        """Localized currency name from the current locale."""
        return i18n.get_locale().currencies[self.code]

    def __eq__(self, other):
        if isinstance(other, Currency):
            return self.code == other.code
        elif isinstance(other, str):
            return self.code == other
        else:
            return NotImplemented

    def __ne__(self, other):
        return not (self == other)

    def __hash__(self):
        return hash(self.code)

    def __repr__(self):
        return f'{self.__class__.__name__}({self.code!r})'

    def __unicode__(self):
        # Used by str_coercible to provide __str__.
        return self.code
|
|
@ -0,0 +1,220 @@
|
|||
import re
|
||||
|
||||
from ..utils import str_coercible
|
||||
|
||||
# Valid ltree path: one or more '.'-separated labels of [A-Za-z0-9_].
path_matcher = re.compile(r'^[A-Za-z0-9_]+(\.[A-Za-z0-9_]+)*$')
|
||||
|
||||
|
||||
@str_coercible
class Ltree:
    """
    Ltree class wraps a valid string label path. It provides various
    convenience properties and methods.

    ::

        from sqlalchemy_utils import Ltree

        Ltree('1.2.3').path  # '1.2.3'


    Ltree always validates the given path.

    ::

        Ltree(None)  # raises TypeError

        Ltree('..')  # raises ValueError


    Validator is also available as class method.

    ::

        Ltree.validate('1.2.3')
        Ltree.validate(None)  # raises TypeError


    Ltree supports equality operators.

    ::

        Ltree('Countries.Finland') == Ltree('Countries.Finland')
        Ltree('Countries.Germany') != Ltree('Countries.Finland')


    Ltree objects are hashable.

    ::

        assert hash(Ltree('Finland')) == hash('Finland')


    Ltree objects have length.

    ::

        assert len(Ltree('1.2')) == 2
        assert len(Ltree('some.one.some.where'))  # 4


    You can easily find subpath indexes.

    ::

        assert Ltree('1.2.3').index('2.3') == 1
        assert Ltree('1.2.3.4.5').index('3.4') == 2


    Ltree objects can be sliced.

    ::

        assert Ltree('1.2.3')[0:2] == Ltree('1.2')
        assert Ltree('1.2.3')[1:] == Ltree('2.3')


    Finding longest common ancestor.

    ::

        assert Ltree('1.2.3.4.5').lca('1.2.3', '1.2.3.4', '1.2.3') == '1.2'
        assert Ltree('1.2.3.4.5').lca('1.2', '1.2.3') == '1'


    Ltree objects can be concatenated.

    ::

        assert Ltree('1.2') + Ltree('1.2') == Ltree('1.2.1.2')
    """

    def __init__(self, path_or_ltree):
        if isinstance(path_or_ltree, Ltree):
            self.path = path_or_ltree.path
        elif isinstance(path_or_ltree, str):
            self.validate(path_or_ltree)
            self.path = path_or_ltree
        else:
            raise TypeError(
                "Ltree() argument must be a string or an Ltree, not '{}'"
                .format(
                    type(path_or_ltree).__name__
                )
            )

    @classmethod
    def validate(cls, path):
        """Raise ValueError if *path* is not a valid ltree label path."""
        if path_matcher.match(path) is None:
            raise ValueError(
                f"'{path}' is not a valid ltree path."
            )

    def __len__(self):
        return len(self.path.split('.'))

    def index(self, other):
        """Return the start index of subpath *other*, or raise ValueError."""
        subpath = Ltree(other).path.split('.')
        parts = self.path.split('.')
        for index, _ in enumerate(parts):
            # Compare a window of len(subpath) labels at each offset.
            if parts[index:len(subpath) + index] == subpath:
                return index
        raise ValueError('subpath not found')

    def descendant_of(self, other):
        """
        is left argument a descendant of right (or equal)?

        ::

            assert Ltree('1.2.3.4.5').descendant_of('1.2.3')
        """
        subpath = self[:len(Ltree(other))]
        return subpath == other

    def ancestor_of(self, other):
        """
        is left argument an ancestor of right (or equal)?

        ::

            assert Ltree('1.2.3').ancestor_of('1.2.3.4.5')
        """
        subpath = Ltree(other)[:len(self)]
        return subpath == self

    def __getitem__(self, key):
        # Int indexing returns a single-label Ltree; slicing rejoins the
        # selected labels into a new Ltree.
        if isinstance(key, int):
            return Ltree(self.path.split('.')[key])
        elif isinstance(key, slice):
            return Ltree('.'.join(self.path.split('.')[key]))
        raise TypeError(
            'Ltree indices must be integers, not {}'.format(
                key.__class__.__name__
            )
        )

    def lca(self, *others):
        """
        Lowest common ancestor, i.e., longest common prefix of paths

        ::

            assert Ltree('1.2.3.4.5').lca('1.2.3', '1.2.3.4', '1.2.3') == '1.2'
        """
        other_parts = [Ltree(other).path.split('.') for other in others]
        parts = self.path.split('.')
        for index, element in enumerate(parts):
            # Stop at the first position where any path diverges or runs
            # out; everything before it is the common ancestor.
            if any(
                other[index] != element or
                len(other) <= index + 1 or
                len(parts) == index + 1
                for other in other_parts
            ):
                if index == 0:
                    # No common prefix at all.
                    return None
                return Ltree('.'.join(parts[0:index]))

    def __add__(self, other):
        return Ltree(self.path + '.' + Ltree(other).path)

    def __radd__(self, other):
        return Ltree(other) + self

    def __eq__(self, other):
        if isinstance(other, Ltree):
            return self.path == other.path
        elif isinstance(other, str):
            return self.path == other
        else:
            return NotImplemented

    def __hash__(self):
        # Hash by path string so Ltree('x') and 'x' collide by design.
        return hash(self.path)

    def __ne__(self, other):
        return not (self == other)

    def __repr__(self):
        return f'{self.__class__.__name__}({self.path!r})'

    def __unicode__(self):
        # Used by str_coercible to provide __str__.
        return self.path

    def __contains__(self, label):
        # Membership of a single label, not of a subpath.
        return label in self.path.split('.')

    def __gt__(self, other):
        # Ordering compares raw path strings; assumes other is Ltree-like.
        return self.path > other.path

    def __lt__(self, other):
        return self.path < other.path

    def __ge__(self, other):
        return self.path >= other.path

    def __le__(self, other):
        return self.path <= other.path
|
|
@ -0,0 +1,54 @@
|
|||
from functools import total_ordering
|
||||
|
||||
from .. import i18n
|
||||
from ..utils import str_coercible
|
||||
|
||||
|
||||
@str_coercible
@total_ordering
class WeekDay:
    """
    Represents a single day of the week as an index in ``0..6``.
    Equality/hashing use the raw index; ordering uses the locale-aware
    position.
    """

    NUM_WEEK_DAYS = 7

    def __init__(self, index):
        if index < 0 or index >= self.NUM_WEEK_DAYS:
            raise ValueError(
                "index must be between 0 and %d" % self.NUM_WEEK_DAYS
            )
        self.index = index

    def __eq__(self, other):
        if not isinstance(other, WeekDay):
            return NotImplemented
        return self.index == other.index

    def __hash__(self):
        return hash(self.index)

    def __lt__(self, other):
        # Ordering follows the locale-relative position, not the raw
        # index; total_ordering derives the rest from this and __eq__.
        return self.position < other.position

    def __repr__(self):
        return f'{self.__class__.__name__}({self.index!r})'

    def __unicode__(self):
        # Used by str_coercible to provide __str__.
        return self.name

    def get_name(self, width='wide', context='format'):
        """Return the localized day name via babel."""
        locale = i18n.get_locale()
        names = i18n.babel.dates.get_day_names(width, context, locale)
        return names[self.index]

    @property
    def name(self):
        return self.get_name()

    @property
    def position(self):
        """Index shifted so the locale's first week day is position 0."""
        offset = i18n.get_locale().first_week_day
        return (self.index - offset) % self.NUM_WEEK_DAYS
|
|
@ -0,0 +1,57 @@
|
|||
from ..utils import str_coercible
|
||||
from .weekday import WeekDay
|
||||
|
||||
|
||||
@str_coercible
class WeekDays:
    """
    A set of :class:`WeekDay` objects, constructible from a 7-character
    bit string (e.g. ``'1000001'``), another WeekDays instance, or an
    iterable of WeekDay objects.
    """
    def __init__(self, bit_string_or_week_days):
        if isinstance(bit_string_or_week_days, str):
            self._days = set()

            if len(bit_string_or_week_days) != WeekDay.NUM_WEEK_DAYS:
                raise ValueError(
                    'Bit string must be {} characters long.'.format(
                        WeekDay.NUM_WEEK_DAYS
                    )
                )

            for index, bit in enumerate(bit_string_or_week_days):
                if bit not in '01':
                    raise ValueError(
                        'Bit string may only contain zeroes and ones.'
                    )
                if bit == '1':
                    self._days.add(WeekDay(index))
        elif isinstance(bit_string_or_week_days, WeekDays):
            # Fixed: copy the underlying set instead of aliasing it, so
            # two WeekDays instances never share mutable state.
            self._days = set(bit_string_or_week_days._days)
        else:
            self._days = set(bit_string_or_week_days)

    def __eq__(self, other):
        if isinstance(other, WeekDays):
            return self._days == other._days
        elif isinstance(other, str):
            return self.as_bit_string() == other
        else:
            return NotImplemented

    def __iter__(self):
        # Iterate in locale-position order (WeekDay ordering).
        yield from sorted(self._days)

    def __contains__(self, value):
        return value in self._days

    def __repr__(self):
        return '{}({!r})'.format(
            self.__class__.__name__,
            self.as_bit_string()
        )

    def __unicode__(self):
        # Used by str_coercible to provide __str__.
        return ', '.join(str(day) for day in self)

    def as_bit_string(self):
        """Render as a 7-character string of '0'/'1' flags, one per
        weekday index."""
        return ''.join(
            '1' if WeekDay(index) in self._days else '0'
            for index in range(WeekDay.NUM_WEEK_DAYS)
        )
|
|
@ -0,0 +1,84 @@
|
|||
import sqlalchemy as sa
|
||||
|
||||
|
||||
class ProxyDict:
    """
    Dict-like facade over a dynamic relationship collection, keyed by
    one attribute (``key_name``) of the child class.
    """
    def __init__(self, parent, collection_name, mapping_attr):
        self.parent = parent
        self.collection_name = collection_name
        self.child_class = mapping_attr.class_
        self.key_name = mapping_attr.key
        # key -> child instance, or None for a known database miss.
        self.cache = {}

    @property
    def collection(self):
        # The underlying relationship collection/query on the parent.
        return getattr(self.parent, self.collection_name)

    def keys(self):
        descriptor = getattr(self.child_class, self.key_name)
        return [x[0] for x in self.collection.values(descriptor)]

    def __contains__(self, key):
        if key in self.cache:
            return self.cache[key] is not None
        return self.fetch(key) is not None

    def has_key(self, key):
        # Legacy dict API alias for __contains__.
        return self.__contains__(key)

    def fetch(self, key):
        """Load the child with the given key, caching the result
        (including None for a miss)."""
        session = sa.orm.object_session(self.parent)
        if session and sa.orm.util.has_identity(self.parent):
            obj = self.collection.filter_by(**{self.key_name: key}).first()
            self.cache[key] = obj
            return obj
        # Parent not persistent yet: implicitly returns None (uncached).

    def create_new_instance(self, key):
        value = self.child_class(**{self.key_name: key})
        self.collection.append(value)
        self.cache[key] = value
        return value

    def __getitem__(self, key):
        # NOTE: unlike a real dict this never raises KeyError -- a
        # missing key creates and appends a new child instance.
        if key in self.cache:
            if self.cache[key] is not None:
                return self.cache[key]
        else:
            value = self.fetch(key)
            if value:
                return value

        return self.create_new_instance(key)

    def __setitem__(self, key, value):
        try:
            existing = self[key]
            self.collection.remove(existing)
        except KeyError:
            pass
        self.collection.append(value)
        self.cache[key] = value
|
||||
|
||||
|
||||
def proxy_dict(parent, collection_name, mapping_attr):
    """
    Return the ProxyDict for ``collection_name`` on ``parent``, creating
    and caching it on ``parent._proxy_dicts`` on first use.
    """
    if not hasattr(parent, '_proxy_dicts'):
        parent._proxy_dicts = {}

    cache = parent._proxy_dicts
    if collection_name not in cache:
        cache[collection_name] = ProxyDict(
            parent,
            collection_name,
            mapping_attr
        )
    return cache[collection_name]
|
||||
|
||||
|
||||
def expire_proxy_dicts(target, context):
    """Reset the cached ProxyDicts on *target*, if it has any."""
    try:
        target._proxy_dicts
    except AttributeError:
        return
    target._proxy_dicts = {}
|
||||
|
||||
|
||||
# Drop every instance's cached ProxyDicts when the mapper expires it,
# presumably so stale cached children are not served after a refresh.
sa.event.listen(sa.orm.Mapper, 'expire', expire_proxy_dicts)
|
|
@ -0,0 +1,173 @@
|
|||
"""
|
||||
QueryChain is a wrapper for sequence of queries.
|
||||
|
||||
|
||||
Features:
|
||||
|
||||
* Easy iteration for sequence of queries
|
||||
* Limit, offset and count which are applied to all queries in the chain
|
||||
* Smart __getitem__ support
|
||||
|
||||
|
||||
Initialization
|
||||
^^^^^^^^^^^^^^
|
||||
|
||||
QueryChain takes iterable of queries as first argument. Additionally limit and
|
||||
offset parameters can be given
|
||||
|
||||
::
|
||||
|
||||
chain = QueryChain([session.query(User), session.query(Article)])
|
||||
|
||||
chain = QueryChain(
|
||||
[session.query(User), session.query(Article)],
|
||||
limit=4
|
||||
)
|
||||
|
||||
|
||||
Simple iteration
|
||||
^^^^^^^^^^^^^^^^
|
||||
::
|
||||
|
||||
chain = QueryChain([session.query(User), session.query(Article)])
|
||||
|
||||
for obj in chain:
|
||||
print(obj)
|
||||
|
||||
|
||||
Limit and offset
|
||||
^^^^^^^^^^^^^^^^
|
||||
|
||||
Lets say you have 5 blog posts, 5 articles and 5 news items in your
|
||||
database.
|
||||
|
||||
::
|
||||
|
||||
chain = QueryChain(
|
||||
[
|
||||
session.query(BlogPost),
|
||||
session.query(Article),
|
||||
session.query(NewsItem)
|
||||
],
|
||||
limit=5
|
||||
)
|
||||
|
||||
list(chain) # all blog posts but not articles and news items
|
||||
|
||||
|
||||
chain = chain.offset(4)
|
||||
list(chain) # last blog post, and first four articles
|
||||
|
||||
|
||||
Just like with original query object the limit and offset can be chained to
|
||||
return a new QueryChain.
|
||||
|
||||
::
|
||||
|
||||
chain = chain.limit(5).offset(7)
|
||||
|
||||
|
||||
Chain slicing
|
||||
^^^^^^^^^^^^^
|
||||
|
||||
::
|
||||
|
||||
chain = QueryChain(
|
||||
[
|
||||
session.query(BlogPost),
|
||||
session.query(Article),
|
||||
session.query(NewsItem)
|
||||
]
|
||||
)
|
||||
|
||||
chain[3:6] # New QueryChain with offset=3 and limit=6
|
||||
|
||||
|
||||
Count
|
||||
^^^^^
|
||||
|
||||
Let's assume that there are five blog posts, five articles and five news
|
||||
items in the database, and you have the following query chain::
|
||||
|
||||
chain = QueryChain(
|
||||
[
|
||||
session.query(BlogPost),
|
||||
session.query(Article),
|
||||
session.query(NewsItem)
|
||||
]
|
||||
)
|
||||
|
||||
You can then get the total number rows returned by the query chain
|
||||
with :meth:`~QueryChain.count`::
|
||||
|
||||
>>> chain.count()
|
||||
15
|
||||
|
||||
|
||||
"""
|
||||
from copy import copy
|
||||
|
||||
|
||||
class QueryChain:
    """
    QueryChain can be used as a wrapper for sequence of queries.

    :param queries: A sequence of SQLAlchemy Query objects
    :param limit: Similar to normal query limit this parameter can be used for
        limiting the number of results for the whole query chain.
    :param offset: Similar to normal query offset this parameter can be used
        for offsetting the query chain as a whole.

    .. versionadded: 0.26.0
    """
    def __init__(self, queries, limit=None, offset=None):
        self.queries = queries
        self._limit = limit
        self._offset = offset

    def __iter__(self):
        # consumed: rows yielded so far (towards the global limit);
        # skipped: rows passed over so far (towards the global offset).
        # NOTE(review): a limit/offset of 0 is treated as "unset" by the
        # truthiness checks below.
        consumed = 0
        skipped = 0
        for query in self.queries:
            # Keep an unmodified copy for count() when the limited/offset
            # query yields nothing.
            query_copy = copy(query)
            if self._limit:
                query = query.limit(self._limit - consumed)
            if self._offset:
                query = query.offset(self._offset - skipped)

            obj_count = 0
            for obj in query:
                consumed += 1
                obj_count += 1
                yield obj

            if not obj_count:
                # Nothing yielded: the whole query was skipped by offset.
                skipped += query_copy.count()
            else:
                skipped += obj_count

    def limit(self, value):
        return self[:value]

    def offset(self, value):
        return self[value:]

    def count(self):
        """
        Return the total number of rows this QueryChain's queries would return.
        """
        return sum(q.count() for q in self.queries)

    def __getitem__(self, key):
        if isinstance(key, slice):
            # Slicing returns a new chain; stop acts as limit and start
            # as offset, falling back to current values when omitted.
            return self.__class__(
                queries=self.queries,
                limit=key.stop if key.stop is not None else self._limit,
                offset=key.start if key.start is not None else self._offset
            )
        else:
            # Integer indexing: first row at that offset (or None).
            for obj in self[key:1]:
                return obj

    def __repr__(self):
        return '<QueryChain at 0x%x>' % id(self)
|
|
@ -0,0 +1,128 @@
|
|||
import sqlalchemy as sa
|
||||
import sqlalchemy.orm
|
||||
from sqlalchemy.sql.util import ClauseAdapter
|
||||
|
||||
from ..compat import _select_args
|
||||
from .chained_join import chained_join # noqa
|
||||
|
||||
|
||||
def path_to_relationships(path, cls):
    """
    Resolve a dot-separated relationship path starting at *cls* into the
    list of relationship attributes it traverses.
    """
    rels = []
    current = cls
    for name in path.split('.'):
        attribute = getattr(current, name)
        rels.append(attribute)
        # Continue resolving from the relationship's target class.
        current = attribute.mapper.class_
    return rels
|
||||
|
||||
|
||||
def adapt_expr(expr, *selectables):
    """Adapt *expr* against each given selectable in turn using
    ClauseAdapter."""
    adapted = expr
    for selectable in selectables:
        adapted = ClauseAdapter(selectable).traverse(adapted)
    return adapted
|
||||
|
||||
|
||||
def inverse_join(selectable, left_alias, right_alias, relationship):
    """
    Join ``selectable`` to ``right_alias`` by traversing ``relationship``
    in the reverse direction (from its target back towards its parent).
    """
    if relationship.property.secondary is not None:
        # Many-to-many: hop through an alias of the association table,
        # adapting both join conditions to the aliased selectables.
        secondary_alias = sa.alias(relationship.property.secondary)
        return selectable.join(
            secondary_alias,
            adapt_expr(
                relationship.property.secondaryjoin,
                sa.inspect(left_alias).selectable,
                secondary_alias
            )
        ).join(
            right_alias,
            adapt_expr(
                relationship.property.primaryjoin,
                sa.inspect(right_alias).selectable,
                secondary_alias
            )
        )
    else:
        # Plain FK relationship: reuse the ON clause of the reversed ORM
        # join.
        join = sa.orm.join(right_alias, left_alias, relationship)
        onclause = join.onclause
        return selectable.join(right_alias, onclause)
|
||||
|
||||
|
||||
def relationship_to_correlation(relationship, alias):
    """
    Return the correlation condition linking *relationship* to *alias*.
    """
    prop = relationship.property
    if prop.secondary is None:
        # Plain relationship: derive the ON clause of the ORM join.
        return sa.orm.join(relationship.parent, alias, relationship).onclause
    # Many-to-many: adapt the primary join to the association alias.
    return adapt_expr(prop.primaryjoin, alias)
|
||||
|
||||
|
||||
def chained_inverse_join(relationships, leaf_model):
    """
    Build a join from ``leaf_model`` back towards the root of the
    relationship chain. Returns ``(selectable, aliases)`` where
    ``aliases`` holds the aliased entity (and possibly the association
    table) for each hop.
    """
    selectable = sa.inspect(leaf_model).selectable
    aliases = [leaf_model]
    for index, relationship in enumerate(relationships[1:]):
        aliases.append(sa.orm.aliased(relationship.mapper.class_))
        selectable = inverse_join(
            selectable,
            aliases[index],
            aliases[index + 1],
            relationships[index]
        )

    if relationships[-1].property.secondary is not None:
        # Final hop is many-to-many: also join its association table so
        # the correlation condition can refer to it.
        secondary_alias = sa.alias(relationships[-1].property.secondary)
        selectable = selectable.join(
            secondary_alias,
            adapt_expr(
                relationships[-1].property.secondaryjoin,
                secondary_alias,
                sa.inspect(aliases[-1]).selectable
            )
        )
        aliases.append(secondary_alias)
    return selectable, aliases
|
||||
|
||||
|
||||
def select_correlated_expression(
    root_model,
    expr,
    path,
    leaf_model,
    from_obj=None,
    order_by=None,
    correlate=True
):
    """Build a correlated subquery selecting *expr* over rows of
    *leaf_model* reachable from *root_model* via the relationship
    *path* (resolved by ``path_to_relationships``).

    :param root_model: model the subquery correlates against
    :param expr: expression to select
    :param path: relationship path from root to leaf
    :param leaf_model: model (or alias) at the end of the path
    :param from_obj:
        optional selectable to correlate against instead of
        ``root_model``
    :param order_by:
        optional iterable of order-by expressions; each is adapted to
        the aliases generated for the join chain
    :param correlate:
        when ``True``, correlate the query so it can be embedded as a
        subquery
    """
    # Resolve the path root->leaf, then reverse it so the join chain is
    # built starting from the leaf.
    relationships = list(reversed(path_to_relationships(path, root_model)))

    query = sa.select(*_select_args(expr))

    # Join from the leaf back towards the root, collecting the aliases
    # created for each hop.
    join_expr, aliases = chained_inverse_join(relationships, leaf_model)

    if order_by:
        query = query.order_by(
            *[
                adapt_expr(
                    o,
                    *(sa.inspect(alias).selectable for alias in aliases)
                )
                for o in order_by
            ]
        )

    # Condition tying the outermost alias back to the parent entity.
    condition = relationship_to_correlation(
        relationships[-1],
        aliases[-1]
    )

    if from_obj is not None:
        condition = adapt_expr(condition, from_obj)

    query = query.select_from(join_expr.selectable)

    if correlate:
        query = query.correlate(
            from_obj if from_obj is not None else root_model
        )
    return query.where(condition)
|
|
@ -0,0 +1,31 @@
|
|||
def chained_join(*relationships):
    """Return a chained Join object for the given relationships.

    The first relationship seeds the join target; every subsequent
    relationship extends the join, hopping through its association
    table first when one exists.
    """
    head = relationships[0].property

    if head.secondary is None:
        joined = head.mapper.class_.__table__
    else:
        joined = head.secondary.join(
            head.mapper.class_.__table__,
            head.secondaryjoin
        )

    for rel in relationships[1:]:
        prop = rel.property
        if prop.secondary is None:
            joined = joined.join(prop.mapper.class_, prop.primaryjoin)
        else:
            # Two hops: onto the association table, then the target.
            joined = joined.join(prop.secondary, prop.primaryjoin)
            joined = joined.join(prop.mapper.class_, prop.secondaryjoin)
    return joined
|
|
@ -0,0 +1,63 @@
|
|||
from functools import wraps
|
||||
|
||||
from sqlalchemy.orm.collections import InstrumentedList as _InstrumentedList
|
||||
|
||||
from .arrow import ArrowType # noqa
|
||||
from .choice import Choice, ChoiceType # noqa
|
||||
from .color import ColorType # noqa
|
||||
from .country import CountryType # noqa
|
||||
from .currency import CurrencyType # noqa
|
||||
from .email import EmailType # noqa
|
||||
from .encrypted.encrypted_type import ( # noqa
|
||||
EncryptedType,
|
||||
StringEncryptedType
|
||||
)
|
||||
from .enriched_datetime.enriched_date_type import EnrichedDateType # noqa
|
||||
from .ip_address import IPAddressType # noqa
|
||||
from .json import JSONType # noqa
|
||||
from .locale import LocaleType # noqa
|
||||
from .ltree import LtreeType # noqa
|
||||
from .password import Password, PasswordType # noqa
|
||||
from .pg_composite import ( # noqa
|
||||
CompositeType,
|
||||
register_composites,
|
||||
remove_composite_listeners
|
||||
)
|
||||
from .phone_number import ( # noqa
|
||||
PhoneNumber,
|
||||
PhoneNumberParseException,
|
||||
PhoneNumberType
|
||||
)
|
||||
from .range import ( # noqa
|
||||
DateRangeType,
|
||||
DateTimeRangeType,
|
||||
Int8RangeType,
|
||||
IntRangeType,
|
||||
NumericRangeType
|
||||
)
|
||||
from .scalar_list import ScalarListException, ScalarListType # noqa
|
||||
from .timezone import TimezoneType # noqa
|
||||
from .ts_vector import TSVectorType # noqa
|
||||
from .url import URLType # noqa
|
||||
from .uuid import UUIDType # noqa
|
||||
from .weekdays import WeekDaysType # noqa
|
||||
|
||||
from .enriched_datetime.enriched_datetime_type import EnrichedDateTimeType # noqa isort:skip
|
||||
|
||||
|
||||
class InstrumentedList(_InstrumentedList):
    """SQLAlchemy ``InstrumentedList`` extended with convenience
    queries over member attributes."""

    def any(self, attr):
        """Return ``True`` if *attr* is truthy on at least one item."""
        values = (getattr(member, attr) for member in self)
        return any(values)

    def all(self, attr):
        """Return ``True`` if *attr* is truthy on every item."""
        values = (getattr(member, attr) for member in self)
        return all(values)
|
||||
|
||||
|
||||
def instrumented_list(f):
    """Decorator that wraps *f* so its iterable result is returned as
    an :class:`InstrumentedList`."""
    @wraps(f)
    def wrapper(*args, **kwargs):
        result = f(*args, **kwargs)
        return InstrumentedList(list(result))
    return wrapper
|
|
@ -0,0 +1,62 @@
|
|||
from ..exceptions import ImproperlyConfigured
|
||||
from .enriched_datetime import ArrowDateTime
|
||||
from .enriched_datetime.enriched_datetime_type import EnrichedDateTimeType
|
||||
|
||||
arrow = None
|
||||
try:
|
||||
import arrow
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
class ArrowType(EnrichedDateTimeType):
    """
    ArrowType provides way of saving Arrow_ objects into database. It
    automatically changes Arrow_ objects to datetime objects on the way in and
    datetime objects back to Arrow_ objects on the way out (when querying
    database). ArrowType needs Arrow_ library installed.

    .. _Arrow: https://github.com/arrow-py/arrow

    ::

        from datetime import datetime
        from sqlalchemy_utils import ArrowType
        import arrow


        class Article(Base):
            __tablename__ = 'article'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))
            created_at = sa.Column(ArrowType)



        article = Article(created_at=arrow.utcnow())


    As you may expect all the arrow goodies come available:

    ::


        article.created_at = article.created_at.shift(hours=-1)

        article.created_at.humanize()
        # 'an hour ago'

    """
    cache_ok = True

    def __init__(self, *args, **kwargs):
        # The arrow import at module level is optional; fail loudly only
        # when the type is actually instantiated without it installed.
        if not arrow:
            raise ImproperlyConfigured(
                "'arrow' package is required to use 'ArrowType'"
            )

        # Delegate all datetime handling to the Arrow-aware processor.
        super().__init__(
            datetime_processor=ArrowDateTime,
            *args,
            **kwargs
        )
|
|
@ -0,0 +1,24 @@
|
|||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import BIT
|
||||
|
||||
|
||||
class BitType(sa.types.TypeDecorator):
    """Column type for storing fixed-width BIT values, using the native
    PostgreSQL ``BIT`` type where available and a fallback
    representation on other dialects."""
    impl = sa.types.BINARY

    cache_ok = True

    def __init__(self, length=1, **kwargs):
        # Number of bits the column holds.
        self.length = length
        sa.types.TypeDecorator.__init__(self, **kwargs)

    def load_dialect_impl(self, dialect):
        # Use the native BIT type for drivers that has it.
        name = dialect.name
        if name == 'postgresql':
            chosen = BIT(self.length)
        elif name == 'sqlite':
            # SQLite stores the value as a plain string.
            chosen = sa.String(self.length)
        else:
            # Fall back to the configured impl type (BINARY by default).
            chosen = type(self.impl)(self.length)
        return dialect.type_descriptor(chosen)
|
|
@ -0,0 +1,225 @@
|
|||
from enum import Enum
|
||||
|
||||
from sqlalchemy import types
|
||||
|
||||
from ..exceptions import ImproperlyConfigured
|
||||
from .scalar_coercible import ScalarCoercible
|
||||
|
||||
|
||||
class Choice:
    """A single (code, value) pair as produced by ``ChoiceType``.

    Equality and hashing are based solely on ``code``; ``str()`` yields
    the human-readable ``value``.
    """

    def __init__(self, code, value):
        self.code = code
        self.value = value

    def __eq__(self, other):
        # Two Choices match on code; anything else is compared to the
        # code directly so raw codes compare equal to Choice objects.
        if isinstance(other, Choice):
            return self.code == other.code
        return other == self.code

    def __hash__(self):
        return hash(self.code)

    def __ne__(self, other):
        return not self.__eq__(other)

    def __str__(self):
        return str(self.value)

    def __repr__(self):
        return 'Choice(code={code}, value={value})'.format(
            code=self.code,
            value=self.value
        )
|
||||
|
||||
|
||||
class ChoiceType(ScalarCoercible, types.TypeDecorator):
    """
    ChoiceType offers way of having fixed set of choices for given column. It
    could work with a list of tuple (a collection of key-value pairs), or
    integrate with :mod:`enum` in the standard library of Python 3.

    Columns with ChoiceTypes are automatically coerced to Choice objects while
    a list of tuple been passed to the constructor. If a subclass of
    :class:`enum.Enum` is passed, columns will be coerced to :class:`enum.Enum`
    objects instead.

    ::

        class User(Base):
            TYPES = [
                ('admin', 'Admin'),
                ('regular-user', 'Regular user')
            ]

            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))
            type = sa.Column(ChoiceType(TYPES))


        user = User(type='admin')
        user.type  # Choice(code='admin', value='Admin')

    Or::

        import enum


        class UserType(enum.Enum):
            admin = 1
            regular = 2


        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))
            type = sa.Column(ChoiceType(UserType, impl=sa.Integer()))


        user = User(type=1)
        user.type  # <UserType.admin: 1>


    ChoiceType is very useful when the rendered values change based on user's
    locale:

    ::

        from babel import lazy_gettext as _


        class User(Base):
            TYPES = [
                ('admin', _('Admin')),
                ('regular-user', _('Regular user'))
            ]

            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))
            type = sa.Column(ChoiceType(TYPES))


        user = User(type='admin')
        user.type  # Choice(code='admin', value='Admin')

        print(user.type)  # 'Admin'

    Or::

        from enum import Enum
        from babel import lazy_gettext as _


        class UserType(Enum):
            admin = 1
            regular = 2


        UserType.admin.label = _('Admin')
        UserType.regular.label = _('Regular user')


        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))
            type = sa.Column(ChoiceType(UserType, impl=sa.Integer()))


        user = User(type=UserType.admin)
        user.type  # <UserType.admin: 1>

        print(user.type.label)  # 'Admin'
    """

    impl = types.Unicode(255)

    cache_ok = True

    def __init__(self, choices, impl=None):
        # Lists become tuples — presumably so the type is hashable for
        # SQLAlchemy's type cache (``cache_ok``); verify against usage.
        self.choices = tuple(choices) if isinstance(choices, list) else choices

        # Dispatch to the Enum-backed or pair-backed implementation.
        if (
            Enum is not None and
            isinstance(choices, type) and
            issubclass(choices, Enum)
        ):
            self.type_impl = EnumTypeImpl(enum_class=choices)
        else:
            self.type_impl = ChoiceTypeImpl(choices=choices)

        if impl:
            # Override the storage type (e.g. sa.Integer for int enums).
            self.impl = impl

    @property
    def python_type(self):
        return self.impl.python_type

    def _coerce(self, value):
        # Delegate coercion of assigned values to the implementation.
        return self.type_impl._coerce(value)

    def process_bind_param(self, value, dialect):
        return self.type_impl.process_bind_param(value, dialect)

    def process_result_value(self, value, dialect):
        return self.type_impl.process_result_value(value, dialect)
|
||||
|
||||
|
||||
class ChoiceTypeImpl:
    """Backend for ``ChoiceType`` when the choices are (code, value)
    pairs rather than an ``Enum`` class."""

    def __init__(self, choices):
        if not choices:
            raise ImproperlyConfigured(
                'ChoiceType needs list of choices defined.'
            )
        # Map code -> display value for quick lookups.
        self.choices_dict = dict(choices)

    def _coerce(self, value):
        # None and already-coerced Choice objects pass through.
        if value is None or isinstance(value, Choice):
            return value
        return Choice(value, self.choices_dict[value])

    def process_bind_param(self, value, dialect):
        # Store only the code; non-Choice values go in unchanged.
        if value and isinstance(value, Choice):
            return value.code
        return value

    def process_result_value(self, value, dialect):
        # Rehydrate truthy stored codes into Choice objects.
        return Choice(value, self.choices_dict[value]) if value else value
|
||||
|
||||
|
||||
class EnumTypeImpl:
    """Backend for ``ChoiceType`` when the choices come from an
    :class:`enum.Enum` subclass."""

    def __init__(self, enum_class):
        if Enum is None:
            raise ImproperlyConfigured(
                "'enum34' package is required to use 'EnumType' in Python "
                "< 3.4"
            )
        if not issubclass(enum_class, Enum):
            raise ImproperlyConfigured(
                "EnumType needs a class of enum defined."
            )

        self.enum_class = enum_class

    def _coerce(self, value):
        # None passes through; anything else must resolve to a member.
        return None if value is None else self.enum_class(value)

    def process_bind_param(self, value, dialect):
        # Persist the member's underlying value.
        member = self._coerce(value)
        return None if member is None else member.value

    def process_result_value(self, value, dialect):
        # Database values come back as enum members.
        return self._coerce(value)
|
|
@ -0,0 +1,79 @@
|
|||
from sqlalchemy import types
|
||||
|
||||
from ..exceptions import ImproperlyConfigured
|
||||
from .scalar_coercible import ScalarCoercible
|
||||
|
||||
colour = None
|
||||
try:
|
||||
import colour
|
||||
python_colour_type = colour.Color
|
||||
except (ImportError, AttributeError):
|
||||
python_colour_type = None
|
||||
|
||||
|
||||
class ColorType(ScalarCoercible, types.TypeDecorator):
    """
    ColorType provides a way for saving Color (from colour_ package) objects
    into database. ColorType saves Color objects as strings on the way in and
    converts them back to objects when querying the database.

    ::


        from colour import Color
        from sqlalchemy_utils import ColorType


        class Document(Base):
            __tablename__ = 'document'
            id = sa.Column(sa.Integer, autoincrement=True)
            name = sa.Column(sa.Unicode(50))
            background_color = sa.Column(ColorType)


        document = Document()
        document.background_color = Color('#F5F5F5')
        session.commit()


    Querying the database returns Color objects:

    ::

        document = session.query(Document).first()

        document.background_color.hex
        # '#f5f5f5'


    .. _colour: https://github.com/vaab/colour
    """
    # Attribute of colour.Color used for serialization ('hex' -> '#f5f5f5').
    STORE_FORMAT = 'hex'
    impl = types.Unicode(20)
    python_type = python_colour_type
    cache_ok = True

    def __init__(self, max_length=20, *args, **kwargs):
        # Fail if colour is not found.
        if colour is None:
            raise ImproperlyConfigured(
                "'colour' package is required to use 'ColorType'"
            )

        super().__init__(*args, **kwargs)
        # Per-instance impl sized to max_length (shadows the class impl).
        self.impl = types.Unicode(max_length)

    def process_bind_param(self, value, dialect):
        # Serialize Color objects via STORE_FORMAT; other values pass
        # through unchanged.
        if value and isinstance(value, colour.Color):
            return str(getattr(value, self.STORE_FORMAT))
        return value

    def process_result_value(self, value, dialect):
        # Truthy stored strings come back as colour.Color objects.
        if value:
            return colour.Color(value)
        return value

    def _coerce(self, value):
        # Coerce assigned raw values (e.g. '#fff') into Color objects.
        if value is not None and not isinstance(value, colour.Color):
            return colour.Color(value)
        return value
|
|
@ -0,0 +1,64 @@
|
|||
from sqlalchemy import types
|
||||
|
||||
from ..primitives import Country
|
||||
from .scalar_coercible import ScalarCoercible
|
||||
|
||||
|
||||
class CountryType(ScalarCoercible, types.TypeDecorator):
    """
    Changes :class:`.Country` objects to a string representation on the way in
    and changes them back to :class:`.Country` objects on the way out.

    In order to use CountryType you need to install Babel_ first.

    .. _Babel: https://babel.pocoo.org/

    ::


        from sqlalchemy_utils import CountryType, Country


        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, autoincrement=True)
            name = sa.Column(sa.Unicode(255))
            country = sa.Column(CountryType)


        user = User()
        user.country = Country('FI')
        session.add(user)
        session.commit()

        user.country  # Country('FI')
        user.country.name  # Finland

        print(user.country)  # Finland


    CountryType is scalar coercible::


        user.country = 'US'
        user.country  # Country('US')
    """
    # Stored as the two-letter country code.
    impl = types.String(2)
    python_type = Country
    cache_ok = True

    def process_bind_param(self, value, dialect):
        # Persist the code; plain strings pass through untouched.
        if isinstance(value, Country):
            return value.code

        if isinstance(value, str):
            return value
        # Any other value (including None) implicitly maps to None.

    def process_result_value(self, value, dialect):
        # Rehydrate non-NULL codes into Country objects.
        if value is not None:
            return Country(value)

    def _coerce(self, value):
        # Coerce assigned raw codes (e.g. 'US') into Country objects.
        if value is not None and not isinstance(value, Country):
            return Country(value)
        return value
|
|
@ -0,0 +1,74 @@
|
|||
from sqlalchemy import types
|
||||
|
||||
from .. import i18n, ImproperlyConfigured
|
||||
from ..primitives import Currency
|
||||
from .scalar_coercible import ScalarCoercible
|
||||
|
||||
|
||||
class CurrencyType(ScalarCoercible, types.TypeDecorator):
    """
    Changes :class:`.Currency` objects to a string representation on the way in
    and changes them back to :class:`.Currency` objects on the way out.

    In order to use CurrencyType you need to install Babel_ first.

    .. _Babel: https://babel.pocoo.org/

    ::


        from sqlalchemy_utils import CurrencyType, Currency


        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, autoincrement=True)
            name = sa.Column(sa.Unicode(255))
            currency = sa.Column(CurrencyType)


        user = User()
        user.currency = Currency('USD')
        session.add(user)
        session.commit()

        user.currency  # Currency('USD')
        user.currency.name  # US Dollar

        str(user.currency)  # US Dollar
        user.currency.symbol  # $



    CurrencyType is scalar coercible::


        user.currency = 'US'
        user.currency  # Currency('US')
    """
    # Stored as the three-letter currency code.
    impl = types.String(3)
    python_type = Currency
    cache_ok = True

    def __init__(self, *args, **kwargs):
        # Currency metadata comes from Babel; fail early if missing.
        if i18n.babel is None:
            raise ImproperlyConfigured(
                "'babel' package is required in order to use CurrencyType."
            )

        super().__init__(*args, **kwargs)

    def process_bind_param(self, value, dialect):
        # Persist the code; plain strings pass through untouched.
        if isinstance(value, Currency):
            return value.code
        elif isinstance(value, str):
            return value
        # Any other value (including None) implicitly maps to None.

    def process_result_value(self, value, dialect):
        # Rehydrate non-NULL codes into Currency objects.
        if value is not None:
            return Currency(value)

    def _coerce(self, value):
        # Coerce assigned raw codes (e.g. 'USD') into Currency objects.
        if value is not None and not isinstance(value, Currency):
            return Currency(value)
        return value
|
|
@ -0,0 +1,48 @@
|
|||
import sqlalchemy as sa
|
||||
|
||||
from ..operators import CaseInsensitiveComparator
|
||||
|
||||
|
||||
class EmailType(sa.types.TypeDecorator):
    """Unicode column that normalizes email addresses to lower case on
    the way into the database and compares them case-insensitively.

    ::


        from sqlalchemy_utils import EmailType


        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.Unicode(255))
            email = sa.Column(EmailType)


        user = User()
        user.email = 'John.Smith@foo.com'
        user.name = 'John Smith'
        session.add(user)
        session.commit()
        # Notice - email in filter() is lowercase.
        user = (session.query(User)
                       .filter(User.email == 'john.smith@foo.com')
                       .one())
        assert user.name == 'John Smith'
    """
    impl = sa.Unicode
    comparator_factory = CaseInsensitiveComparator
    cache_ok = True

    def __init__(self, length=255, *args, **kwargs):
        # Default to a 255-character column.
        super().__init__(length=length, *args, **kwargs)

    def process_bind_param(self, value, dialect):
        # Lowercase before persisting; None stays None.
        return value if value is None else value.lower()

    @property
    def python_type(self):
        return self.impl.type.python_type
|
|
@ -0,0 +1 @@
|
|||
# Module for encrypted type
|
|
@ -0,0 +1,508 @@
|
|||
import base64
|
||||
import datetime
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
|
||||
from sqlalchemy.types import LargeBinary, String, TypeDecorator
|
||||
|
||||
from sqlalchemy_utils.exceptions import ImproperlyConfigured
|
||||
from sqlalchemy_utils.types.encrypted.padding import PADDING_MECHANISM
|
||||
from sqlalchemy_utils.types.json import JSONType
|
||||
from sqlalchemy_utils.types.scalar_coercible import ScalarCoercible
|
||||
|
||||
cryptography = None
|
||||
try:
|
||||
import cryptography
|
||||
from cryptography.exceptions import InvalidTag
|
||||
from cryptography.fernet import Fernet
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives import hashes
|
||||
from cryptography.hazmat.primitives.ciphers import (
|
||||
algorithms,
|
||||
Cipher,
|
||||
modes
|
||||
)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
dateutil = None
|
||||
try:
|
||||
import dateutil
|
||||
from dateutil.parser import parse as datetime_parse
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
class InvalidCiphertextError(Exception):
    """Raised when stored ciphertext is malformed or fails
    authentication during decryption."""
|
||||
|
||||
|
||||
class EncryptionDecryptionBaseEngine:
    """A base encryption and decryption engine.

    This class must be sub-classed in order to create
    new engines.
    """

    def _update_key(self, key):
        # Derive a fixed-length engine key by SHA-256 hashing the
        # (possibly textual) user-supplied key.
        if isinstance(key, str):
            key = key.encode()
        digest = hashes.Hash(hashes.SHA256(), backend=default_backend())
        digest.update(key)
        engine_key = digest.finalize()

        # Subclasses configure themselves from the 32-byte derived key.
        self._initialize_engine(engine_key)

    def encrypt(self, value):
        raise NotImplementedError('Subclasses must implement this!')

    def decrypt(self, value):
        raise NotImplementedError('Subclasses must implement this!')
|
||||
|
||||
|
||||
class AesEngine(EncryptionDecryptionBaseEngine):
    """Provide AES encryption and decryption methods.

    You may also consider using the AesGcmEngine instead -- that may be
    a better fit for some cases.

    You should NOT use the AesGcmEngine if you want to be able to search
    for a row based on the value of an encrypted column. Use AesEngine
    instead, since that allows you to perform such searches.

    If you don't need to search by the value of an encypted column, the
    AesGcmEngine provides better security.
    """

    # AES block size in bytes; also used to size the padding engine.
    BLOCK_SIZE = 16

    def _initialize_engine(self, parent_class_key):
        self.secret_key = parent_class_key
        # The IV is derived from the key itself, so a given plaintext
        # always produces the same ciphertext. That determinism is what
        # makes equality searches on encrypted columns possible -- at
        # the cost of weaker security than AesGcmEngine.
        self.iv = self.secret_key[:16]
        self.cipher = Cipher(
            algorithms.AES(self.secret_key),
            modes.CBC(self.iv),
            backend=default_backend()
        )

    def _set_padding_mechanism(self, padding_mechanism=None):
        """Set the padding mechanism."""
        # Validate explicit names against the registry of mechanisms.
        if isinstance(padding_mechanism, str):
            if padding_mechanism not in PADDING_MECHANISM.keys():
                raise ImproperlyConfigured(
                    "There is not padding mechanism with name {}".format(
                        padding_mechanism
                    )
                )

        if padding_mechanism is None:
            padding_mechanism = 'naive'

        padding_class = PADDING_MECHANISM[padding_mechanism]
        self.padding_engine = padding_class(self.BLOCK_SIZE)

    def encrypt(self, value):
        # Non-strings are serialized with repr() before encryption.
        if not isinstance(value, str):
            value = repr(value)
        if isinstance(value, str):
            value = str(value)
        value = value.encode()
        # Pad to a whole number of AES blocks, then CBC-encrypt and
        # base64-encode for storage in a text column.
        value = self.padding_engine.pad(value)
        encryptor = self.cipher.encryptor()
        encrypted = encryptor.update(value) + encryptor.finalize()
        encrypted = base64.b64encode(encrypted)
        return encrypted.decode('utf-8')

    def decrypt(self, value):
        if isinstance(value, str):
            value = str(value)
        decryptor = self.cipher.decryptor()
        decrypted = base64.b64decode(value)
        decrypted = decryptor.update(decrypted) + decryptor.finalize()
        decrypted = self.padding_engine.unpad(decrypted)
        if not isinstance(decrypted, str):
            try:
                decrypted = decrypted.decode('utf-8')
            except UnicodeDecodeError:
                # Garbage after unpadding -> almost certainly a wrong key.
                raise ValueError('Invalid decryption key')
        return decrypted
|
||||
|
||||
|
||||
class AesGcmEngine(EncryptionDecryptionBaseEngine):
    """Provide AES/GCM encryption and decryption methods.

    You may also consider using the AesEngine instead -- that may be
    a better fit for some cases.

    You should NOT use this AesGcmEngine if you want to be able to search
    for a row based on the value of an encrypted column. Use AesEngine
    instead, since that allows you to perform such searches.

    If you don't need to search by the value of an encypted column, the
    AesGcmEngine provides better security.
    """

    BLOCK_SIZE = 16
    # Nonce (IV) length used for GCM.
    IV_BYTES_NEEDED = 12
    # Length of the authentication tag appended by GCM.
    TAG_SIZE_BYTES = BLOCK_SIZE

    def _initialize_engine(self, parent_class_key):
        self.secret_key = parent_class_key

    def encrypt(self, value):
        # Non-strings are serialized with repr() before encryption.
        if not isinstance(value, str):
            value = repr(value)
        if isinstance(value, str):
            value = str(value)
        value = value.encode()
        # Fresh random IV per call -> ciphertext is non-deterministic
        # (hence not usable for equality searches).
        iv = os.urandom(self.IV_BYTES_NEEDED)
        cipher = Cipher(
            algorithms.AES(self.secret_key),
            modes.GCM(iv),
            backend=default_backend()
        )
        encryptor = cipher.encryptor()
        encrypted = encryptor.update(value) + encryptor.finalize()
        assert len(encryptor.tag) == self.TAG_SIZE_BYTES
        # Stored layout: iv || tag || ciphertext, base64-encoded.
        encrypted = base64.b64encode(iv + encryptor.tag + encrypted)
        return encrypted.decode('utf-8')

    def decrypt(self, value):
        if isinstance(value, str):
            value = str(value)
        decrypted = base64.b64decode(value)
        # The payload must at least contain the IV and the tag.
        if len(decrypted) < self.IV_BYTES_NEEDED + self.TAG_SIZE_BYTES:
            raise InvalidCiphertextError()
        # Split the iv || tag || ciphertext layout written by encrypt().
        iv = decrypted[:self.IV_BYTES_NEEDED]
        tag = decrypted[self.IV_BYTES_NEEDED:
                        self.IV_BYTES_NEEDED + self.TAG_SIZE_BYTES]
        decrypted = decrypted[self.IV_BYTES_NEEDED + self.TAG_SIZE_BYTES:]
        cipher = Cipher(
            algorithms.AES(self.secret_key),
            modes.GCM(iv, tag),
            backend=default_backend()
        )
        decryptor = cipher.decryptor()
        try:
            decrypted = decryptor.update(decrypted) + decryptor.finalize()
        except InvalidTag:
            # Authentication failed: wrong key or tampered ciphertext.
            raise InvalidCiphertextError()
        if not isinstance(decrypted, str):
            try:
                decrypted = decrypted.decode('utf-8')
            except UnicodeDecodeError:
                raise InvalidCiphertextError()
        return decrypted
|
||||
|
||||
|
||||
class FernetEngine(EncryptionDecryptionBaseEngine):
    """Provide Fernet encryption and decryption methods."""

    def _initialize_engine(self, parent_class_key):
        # Fernet requires a urlsafe-base64-encoded 32-byte key; the
        # base class supplies a 32-byte SHA-256 digest.
        self.secret_key = base64.urlsafe_b64encode(parent_class_key)
        self.fernet = Fernet(self.secret_key)

    def encrypt(self, value):
        # Non-strings are serialized with repr() before encryption.
        if not isinstance(value, str):
            value = repr(value)
        if isinstance(value, str):
            value = str(value)
        value = value.encode()
        encrypted = self.fernet.encrypt(value)
        return encrypted.decode('utf-8')

    def decrypt(self, value):
        if isinstance(value, str):
            value = str(value)
        decrypted = self.fernet.decrypt(value.encode())
        if not isinstance(decrypted, str):
            decrypted = decrypted.decode('utf-8')
        return decrypted
|
||||
|
||||
|
||||
class StringEncryptedType(TypeDecorator, ScalarCoercible):
|
||||
"""
|
||||
EncryptedType provides a way to encrypt and decrypt values,
|
||||
to and from databases, that their type is a basic SQLAlchemy type.
|
||||
For example Unicode, String or even Boolean.
|
||||
On the way in, the value is encrypted and on the way out the stored value
|
||||
is decrypted.
|
||||
|
||||
EncryptedType needs Cryptography_ library in order to work.
|
||||
|
||||
When declaring a column which will be of type EncryptedType
|
||||
it is better to be as precise as possible and follow the pattern
|
||||
below.
|
||||
|
||||
.. _Cryptography: https://cryptography.io/en/latest/
|
||||
|
||||
::
|
||||
|
||||
|
||||
a_column = sa.Column(EncryptedType(sa.Unicode,
|
||||
secret_key,
|
||||
FernetEngine))
|
||||
|
||||
another_column = sa.Column(EncryptedType(sa.Unicode,
|
||||
secret_key,
|
||||
AesEngine,
|
||||
'pkcs5'))
|
||||
|
||||
|
||||
A more complete example is given below.
|
||||
|
||||
::
|
||||
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import create_engine
|
||||
try:
|
||||
from sqlalchemy.orm import declarative_base
|
||||
except ImportError:
|
||||
# sqlalchemy 1.3
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from sqlalchemy_utils import EncryptedType
|
||||
from sqlalchemy_utils.types.encrypted.encrypted_type import AesEngine
|
||||
|
||||
secret_key = 'secretkey1234'
|
||||
# setup
|
||||
engine = create_engine('sqlite:///:memory:')
|
||||
connection = engine.connect()
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = "user"
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
username = sa.Column(EncryptedType(sa.Unicode,
|
||||
secret_key,
|
||||
AesEngine,
|
||||
'pkcs5'))
|
||||
access_token = sa.Column(EncryptedType(sa.String,
|
||||
secret_key,
|
||||
AesEngine,
|
||||
'pkcs5'))
|
||||
is_active = sa.Column(EncryptedType(sa.Boolean,
|
||||
secret_key,
|
||||
AesEngine,
|
||||
'zeroes'))
|
||||
number_of_accounts = sa.Column(EncryptedType(sa.Integer,
|
||||
secret_key,
|
||||
AesEngine,
|
||||
'oneandzeroes'))
|
||||
|
||||
|
||||
sa.orm.configure_mappers()
|
||||
Base.metadata.create_all(connection)
|
||||
|
||||
# create a configured "Session" class
|
||||
Session = sessionmaker(bind=connection)
|
||||
|
||||
# create a Session
|
||||
session = Session()
|
||||
|
||||
# example
|
||||
user_name = 'secret_user'
|
||||
test_token = 'atesttoken'
|
||||
active = True
|
||||
num_of_accounts = 2
|
||||
|
||||
user = User(username=user_name, access_token=test_token,
|
||||
is_active=active, number_of_accounts=num_of_accounts)
|
||||
session.add(user)
|
||||
session.commit()
|
||||
|
||||
user_id = user.id
|
||||
|
||||
session.expunge_all()
|
||||
|
||||
user_instance = session.query(User).get(user_id)
|
||||
|
||||
print('id: {}'.format(user_instance.id))
|
||||
print('username: {}'.format(user_instance.username))
|
||||
print('token: {}'.format(user_instance.access_token))
|
||||
print('active: {}'.format(user_instance.is_active))
|
||||
print('accounts: {}'.format(user_instance.number_of_accounts))
|
||||
|
||||
# teardown
|
||||
session.close_all()
|
||||
Base.metadata.drop_all(connection)
|
||||
connection.close()
|
||||
engine.dispose()
|
||||
|
||||
The key parameter accepts a callable to allow for the key to change
|
||||
per-row instead of being fixed for the whole table.
|
||||
|
||||
::
|
||||
|
||||
|
||||
def get_key():
|
||||
return 'dynamic-key'
|
||||
|
||||
class User(Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
username = sa.Column(EncryptedType(
|
||||
sa.Unicode, get_key))
|
||||
|
||||
"""
|
||||
impl = String
|
||||
cache_ok = True
|
||||
|
||||
    def __init__(
        self,
        type_in=None,
        key=None,
        engine=None,
        padding=None,
        **kwargs
    ):
        """Initialize the encrypted type wrapper.

        :param type_in: underlying SQLAlchemy type (instance or class) whose
            values are encrypted before storage; defaults to ``String()``.
        :param key: encryption key, or a callable returning the key so it can
            vary per use (resolved in ``_update_key``).
        :param engine: encryption engine class; defaults to ``AesEngine``.
        :param padding: padding mechanism name, only used by ``AesEngine``.
        :raises ImproperlyConfigured: if 'cryptography' is not installed.
        """
        if not cryptography:
            raise ImproperlyConfigured(
                "'cryptography' is required to use EncryptedType"
            )
        super().__init__(**kwargs)
        # set the underlying type; accept either a class or an instance
        if type_in is None:
            type_in = String()
        elif isinstance(type_in, type):
            type_in = type_in()
        self.underlying_type = type_in
        self._key = key
        if not engine:
            engine = AesEngine
        # The engine is instantiated per-type; only AesEngine understands
        # a padding mechanism.
        self.engine = engine()
        if isinstance(self.engine, AesEngine):
            self.engine._set_padding_mechanism(padding)
|
||||
|
||||
    @property
    def key(self):
        """Encryption key (may be a callable resolved lazily per operation)."""
        return self._key

    @key.setter
    def key(self, value):
        # The raw key (or callable) is stored; it is pushed into the engine
        # lazily by _update_key before each encrypt/decrypt.
        self._key = value
|
||||
|
||||
    def _update_key(self):
        """Resolve the (possibly callable) key and hand it to the engine."""
        key = self._key() if callable(self._key) else self._key
        self.engine._update_key(key)
|
||||
|
||||
    def process_bind_param(self, value, dialect):
        """Encrypt a value on the way in.

        The value is first serialized through the underlying type's own
        ``process_bind_param`` when it has one; otherwise a few common
        python types (bool, date/time, JSON) are serialized by hand.
        ``None`` is passed through unencrypted.
        """
        if value is not None:
            self._update_key()

            try:
                value = self.underlying_type.process_bind_param(
                    value, dialect
                )

            except AttributeError:
                # Doesn't have 'process_bind_param'

                # Handle 'boolean' and 'dates'
                type_ = self.underlying_type.python_type
                if issubclass(type_, bool):
                    # Stored as the literal strings 'true'/'false'; the
                    # decrypt path compares against 'true'.
                    value = 'true' if value else 'false'

                elif issubclass(type_, (datetime.date, datetime.time)):
                    value = value.isoformat()

                elif issubclass(type_, JSONType):
                    value = json.dumps(value)

            return self.engine.encrypt(value)
|
||||
|
||||
    def process_result_value(self, value, dialect):
        """Decrypt value on the way out.

        Mirrors ``process_bind_param``: after decryption the plaintext is
        deserialized through the underlying type's ``process_result_value``
        when available, otherwise by the hand-written bool/date/JSON paths,
        falling back to calling the underlying python type's constructor.
        ``None`` is passed through untouched.
        """
        if value is not None:
            self._update_key()
            decrypted_value = self.engine.decrypt(value)

            try:
                return self.underlying_type.process_result_value(
                    decrypted_value, dialect
                )

            except AttributeError:
                # Doesn't have 'process_result_value'

                # Handle 'boolean' and 'dates'
                type_ = self.underlying_type.python_type
                date_types = [datetime.datetime, datetime.time, datetime.date]

                if issubclass(type_, bool):
                    # Inverse of the 'true'/'false' encoding used on bind.
                    return decrypted_value == 'true'

                elif type_ in date_types:
                    return DatetimeHandler.process_value(
                        decrypted_value, type_
                    )

                elif issubclass(type_, JSONType):
                    return json.loads(decrypted_value)

                # Handle all others
                return self.underlying_type.python_type(decrypted_value)
|
||||
|
||||
def _coerce(self, value):
|
||||
if isinstance(self.underlying_type, ScalarCoercible):
|
||||
return self.underlying_type._coerce(value)
|
||||
|
||||
return value
|
||||
|
||||
|
||||
class EncryptedType(StringEncryptedType):
    """Deprecated ``LargeBinary``-backed variant of ``StringEncryptedType``.

    Behaves like :class:`StringEncryptedType` but stores the ciphertext in a
    ``LargeBinary`` column, converting between ``str`` and ``bytes`` at the
    column boundary. Instantiating it emits a ``DeprecationWarning``.
    """
    impl = LargeBinary

    def __init__(self, *args, **kwargs):
        warnings.warn(
            "The 'EncryptedType' class will change implementation from "
            "'LargeBinary' to 'String' in a future version. Use "
            "'StringEncryptedType' to use the 'String' implementation.",
            DeprecationWarning, stacklevel=2)
        super().__init__(*args, **kwargs)

    def process_bind_param(self, value, dialect):
        # Parent produces str ciphertext; a LargeBinary column needs bytes.
        value = super().process_bind_param(value=value, dialect=dialect)
        if isinstance(value, str):
            value = value.encode()
        return value

    def process_result_value(self, value, dialect):
        # Column yields bytes; the parent decrypts str ciphertext.
        if isinstance(value, bytes):
            value = value.decode()
        value = super().process_result_value(value=value, dialect=dialect)
        return value
|
||||
|
||||
|
||||
class DatetimeHandler:
    """
    DatetimeHandler is responsible for parsing strings and
    returning the appropriate date, datetime or time objects.
    """

    @classmethod
    def process_value(cls, value, python_type):
        """
        process_value returns a datetime, date
        or time object according to a given string
        value and a python type.

        :param value: string parseable by dateutil's parser.
        :param python_type: ``datetime.datetime``, ``datetime.date`` or
            ``datetime.time`` — decides how the parsed value is narrowed.
        :raises ImproperlyConfigured: if python-dateutil is not installed.
        """
        if not dateutil:
            raise ImproperlyConfigured(
                "'python-dateutil' is required to process datetimes"
            )

        # Parse to a full datetime first, then narrow to the requested type.
        return_value = datetime_parse(value)

        if issubclass(python_type, datetime.datetime):
            return return_value
        elif issubclass(python_type, datetime.time):
            return return_value.time()
        elif issubclass(python_type, datetime.date):
            return return_value.date()
|
|
@ -0,0 +1,142 @@
|
|||
class InvalidPaddingError(Exception):
    """Raised when unpadding detects structurally invalid padding bytes."""
    pass
|
||||
|
||||
|
||||
class Padding:
    """Base class for padding and unpadding.

    Subclasses implement :meth:`pad` and :meth:`unpad` for a fixed cipher
    block size, expressed in bytes.
    """

    def __init__(self, block_size):
        self.block_size = block_size

    def pad(self, value):
        """Extend *value* to a multiple of the block size."""
        raise NotImplementedError('Subclasses must implement this!')

    def unpad(self, value):
        """Strip whatever :meth:`pad` appended."""
        raise NotImplementedError('Subclasses must implement this!')
|
||||
|
||||
|
||||
class PKCS5Padding(Padding):
    """Provide PKCS5 padding and unpadding.

    Pads with 1..block_size bytes, each byte equal to the padding length,
    so a full extra block is added when the input is already aligned.
    """

    def pad(self, value):
        """Return *value* (as bytes) with PKCS5 padding appended."""
        if not isinstance(value, bytes):
            value = value.encode()
        # Always >= 1, so unpad can rely on the last byte being padding.
        padding_length = (self.block_size - len(value) % self.block_size)
        padding_sequence = padding_length * bytes((padding_length,))
        value_with_padding = value + padding_sequence

        return value_with_padding

    def unpad(self, value):
        """Validate and strip PKCS5 padding.

        :raises InvalidPaddingError: on any structural inconsistency.
        """
        # Perform some input validations.
        # In case of error, we throw a generic InvalidPaddingError()
        if not value or len(value) < self.block_size:
            # PKCS5 padded output will always be at least 1 block size
            raise InvalidPaddingError()
        if len(value) % self.block_size != 0:
            # PKCS5 padded output will be a multiple of the block size
            raise InvalidPaddingError()
        if isinstance(value, bytes):
            padding_length = value[-1]
        if isinstance(value, str):
            padding_length = ord(value[-1])
        if padding_length == 0 or padding_length > self.block_size:
            raise InvalidPaddingError()

        def convert_byte_or_char_to_number(x):
            # Iterating bytes yields ints; iterating str yields 1-char strs.
            return ord(x) if isinstance(x, str) else x
        # Every padding byte must equal the padding length.
        if any([padding_length != convert_byte_or_char_to_number(x)
                for x in value[-padding_length:]]):
            raise InvalidPaddingError()

        value_without_padding = value[0:-padding_length]

        return value_without_padding
|
||||
|
||||
|
||||
class OneAndZeroesPadding(Padding):
    """Provide the one and zeroes padding and unpadding.

    This mechanism pads with 0x80 followed by zero bytes.
    For unpadding it strips off all trailing zero bytes and the 0x80 byte.

    NOTE(review): ``unpad`` uses ``rstrip``, which removes *all* trailing
    0x80/0x00 bytes — plaintexts that themselves end in such bytes cannot
    round-trip exactly. Presumably accepted for compatibility; confirm
    before relying on it for binary payloads.
    """

    BYTE_80 = 0x80
    BYTE_00 = 0x00

    def pad(self, value):
        """Return *value* (as bytes) padded with 0x80 then zero bytes."""
        if not isinstance(value, bytes):
            value = value.encode()
        padding_length = (self.block_size - len(value) % self.block_size)
        one_part_bytes = bytes((self.BYTE_80,))
        zeroes_part_bytes = (padding_length - 1) * bytes((self.BYTE_00,))
        padding_sequence = one_part_bytes + zeroes_part_bytes
        value_with_padding = value + padding_sequence

        return value_with_padding

    def unpad(self, value):
        """Strip trailing zero bytes, then trailing 0x80 bytes."""
        value_without_padding = value.rstrip(bytes((self.BYTE_00,)))
        value_without_padding = value_without_padding.rstrip(
            bytes((self.BYTE_80,)))

        return value_without_padding
|
||||
|
||||
|
||||
class ZeroesPadding(Padding):
    """Provide zeroes padding and unpadding.

    This mechanism pads with 0x00 except the last byte equals
    to the padding length. For unpadding it reads the last byte
    and strips off that many bytes.

    NOTE(review): unlike PKCS5Padding, ``unpad`` performs no validation of
    the trailing length byte.
    """

    BYTE_00 = 0x00

    def pad(self, value):
        """Return *value* (as bytes) padded with zeroes plus a length byte."""
        if not isinstance(value, bytes):
            value = value.encode()
        # Always >= 1, so the final length byte is always present.
        padding_length = (self.block_size - len(value) % self.block_size)
        zeroes_part_bytes = (padding_length - 1) * bytes((self.BYTE_00,))
        last_part_bytes = bytes((padding_length,))
        padding_sequence = zeroes_part_bytes + last_part_bytes
        value_with_padding = value + padding_sequence

        return value_with_padding

    def unpad(self, value):
        """Strip as many bytes as the final byte declares."""
        if isinstance(value, bytes):
            padding_length = value[-1]
        if isinstance(value, str):
            padding_length = ord(value[-1])
        value_without_padding = value[0:-padding_length]

        return value_without_padding
|
||||
|
||||
|
||||
class NaivePadding(Padding):
    """Naive padding and unpadding using '*'.

    Pads up to the next block boundary with ``*`` bytes and unpads by
    stripping every trailing ``*``. The class is provided only for
    backwards compatibility; values ending in ``*`` do not round-trip.
    """

    CHARACTER = b'*'

    def pad(self, value):
        remainder = len(value) % self.block_size
        return value + (self.block_size - remainder) * self.CHARACTER

    def unpad(self, value):
        return value.rstrip(self.CHARACTER)
|
||||
|
||||
|
||||
# Registry of the padding mechanism names accepted as the ``padding``
# argument of the encrypted types (handed to the AES engine) mapped to
# their implementing classes.
PADDING_MECHANISM = {
    'pkcs5': PKCS5Padding,
    'oneandzeroes': OneAndZeroesPadding,
    'zeroes': ZeroesPadding,
    'naive': NaivePadding
}
|
|
@ -0,0 +1,4 @@
|
|||
# Module for enriched date, datetime type
|
||||
from .arrow_datetime import ArrowDateTime # noqa
|
||||
from .pendulum_date import PendulumDate # noqa
|
||||
from .pendulum_datetime import PendulumDateTime # noqa
|
|
@ -0,0 +1,39 @@
|
|||
from collections.abc import Iterable
|
||||
from datetime import datetime
|
||||
|
||||
from ...exceptions import ImproperlyConfigured
|
||||
|
||||
arrow = None
|
||||
try:
|
||||
import arrow
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
class ArrowDateTime:
    """Adapter converting between ``arrow`` objects and column datetimes.

    Used by the enriched datetime type; methods take the column's ``impl``
    type so timezone handling can follow the column configuration.
    """

    def __init__(self):
        # Arrow is optional at module level; fail loudly only when used.
        if not arrow:
            raise ImproperlyConfigured(
                "'arrow' package is required to use 'ArrowDateTime'"
            )

    def _coerce(self, impl, value):
        """Coerce strings, iterables of ``arrow.get`` args, or datetimes."""
        if isinstance(value, str):
            value = arrow.get(value)
        elif isinstance(value, Iterable):
            value = arrow.get(*value)
        elif isinstance(value, datetime):
            value = arrow.get(value)
        return value

    def process_bind_param(self, impl, value, dialect):
        """Normalize to UTC; drop tzinfo when the column is naive."""
        if value:
            utc_val = self._coerce(impl, value).to('UTC')
            return utc_val.datetime\
                if impl.timezone else utc_val.naive
        return value

    def process_result_value(self, impl, value, dialect):
        """Wrap a database datetime back into an arrow object."""
        if value:
            return arrow.get(value)
        return value
|
|
@ -0,0 +1,50 @@
|
|||
from sqlalchemy import types
|
||||
|
||||
from ..scalar_coercible import ScalarCoercible
|
||||
from .pendulum_date import PendulumDate
|
||||
|
||||
|
||||
class EnrichedDateType(types.TypeDecorator, ScalarCoercible):
    """
    Supported for pendulum only.

    Stores plain dates while exposing pendulum date objects in Python,
    delegating conversion to a processor object (default
    :class:`PendulumDate`).

    Example::


        from sqlalchemy_utils import EnrichedDateType
        import pendulum


        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            birthday = sa.Column(EnrichedDateType)


        user = User()
        user.birthday = pendulum.datetime(year=1995, month=7, day=11)
        session.add(user)
        session.commit()
    """
    impl = types.Date
    cache_ok = True

    def __init__(self, date_processor=PendulumDate, *args, **kwargs):
        # date_processor: class implementing _coerce /
        # process_bind_param / process_result_value (see PendulumDate).
        super().__init__(*args, **kwargs)
        self.date_object = date_processor()

    def _coerce(self, value):
        return self.date_object._coerce(self.impl, value)

    def process_bind_param(self, value, dialect):
        return self.date_object.process_bind_param(self.impl, value, dialect)

    def process_result_value(self, value, dialect):
        return self.date_object.process_result_value(self.impl, value, dialect)

    def process_literal_param(self, value, dialect):
        return value

    @property
    def python_type(self):
        return self.impl.type.python_type
|
|
@ -0,0 +1,51 @@
|
|||
from sqlalchemy import types
|
||||
|
||||
from ..scalar_coercible import ScalarCoercible
|
||||
from .pendulum_datetime import PendulumDateTime
|
||||
|
||||
|
||||
class EnrichedDateTimeType(types.TypeDecorator, ScalarCoercible):
    """
    Supported for arrow and pendulum.

    Stores plain datetimes while exposing pendulum (or arrow) objects in
    Python, delegating conversion to a processor object (default
    :class:`PendulumDateTime`).

    Example::


        from sqlalchemy_utils import EnrichedDateTimeType
        import pendulum


        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            created_at = sa.Column(EnrichedDateTimeType)
            # created_at = sa.Column(
            #     EnrichedDateTimeType(datetime_processor=ArrowDateTime))


        user = User()
        user.created_at = pendulum.now()
        session.add(user)
        session.commit()
    """
    impl = types.DateTime
    cache_ok = True

    def __init__(self, datetime_processor=PendulumDateTime, *args, **kwargs):
        # datetime_processor: class implementing _coerce /
        # process_bind_param / process_result_value
        # (see PendulumDateTime, ArrowDateTime).
        super().__init__(*args, **kwargs)
        self.dt_object = datetime_processor()

    def _coerce(self, value):
        return self.dt_object._coerce(self.impl, value)

    def process_bind_param(self, value, dialect):
        return self.dt_object.process_bind_param(self.impl, value, dialect)

    def process_result_value(self, value, dialect):
        return self.dt_object.process_result_value(self.impl, value, dialect)

    def process_literal_param(self, value, dialect):
        return value

    @property
    def python_type(self):
        return self.impl.type.python_type
|
|
@ -0,0 +1,32 @@
|
|||
from ...exceptions import ImproperlyConfigured
|
||||
from .pendulum_datetime import PendulumDateTime
|
||||
|
||||
pendulum = None
|
||||
try:
|
||||
import pendulum
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
class PendulumDate(PendulumDateTime):
    """Date-only variant of :class:`PendulumDateTime`.

    Coerces values through the datetime logic of the parent, then narrows
    to a ``pendulum.Date``.
    """

    def __init__(self):
        if not pendulum:
            raise ImproperlyConfigured(
                "'pendulum' package is required to use 'PendulumDate'"
            )

    def _coerce(self, impl, value):
        """Coerce to a pendulum date (via the parent's datetime coercion)."""
        if value:
            if not isinstance(value, pendulum.Date):
                value = super()._coerce(impl, value).date()
        return value

    def process_result_value(self, impl, value, dialect):
        """Re-parse a stored date into a pendulum date."""
        if value:
            return pendulum.parse(value.isoformat()).date()
        return value

    def process_bind_param(self, impl, value, dialect):
        if value:
            return self._coerce(impl, value)
        return value
|
|
@ -0,0 +1,49 @@
|
|||
from datetime import datetime
|
||||
|
||||
from ...exceptions import ImproperlyConfigured
|
||||
|
||||
pendulum = None
|
||||
try:
|
||||
import pendulum
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
class PendulumDateTime:
    """Adapter converting between ``pendulum`` objects and column datetimes."""

    def __init__(self):
        # Pendulum is optional at module level; fail loudly only when used.
        if not pendulum:
            raise ImproperlyConfigured(
                "'pendulum' package is required to use 'PendulumDateTime'"
            )

    def _coerce(self, impl, value):
        """Coerce timestamps, digit strings, strings and datetimes.

        Numeric values and all-digit strings are treated as Unix
        timestamps; other strings are parsed; naive/aware datetimes are
        rebuilt field-by-field as pendulum datetimes.
        """
        if value is not None:
            if isinstance(value, pendulum.DateTime):
                pass
            elif isinstance(value, (int, float)):
                value = pendulum.from_timestamp(value)
            elif isinstance(value, str) and value.isdigit():
                value = pendulum.from_timestamp(int(value))
            elif isinstance(value, datetime):
                value = pendulum.datetime(
                    value.year,
                    value.month,
                    value.day,
                    value.hour,
                    value.minute,
                    value.second,
                    value.microsecond
                )
            else:
                value = pendulum.parse(value)
        return value

    def process_bind_param(self, impl, value, dialect):
        """Store everything normalized to UTC."""
        if value:
            return self._coerce(impl, value).in_tz('UTC')
        return value

    def process_result_value(self, impl, value, dialect):
        """Re-parse a stored datetime into a pendulum datetime."""
        if value:
            return pendulum.parse(value.isoformat())
        return value
|
|
@ -0,0 +1,52 @@
|
|||
from ipaddress import ip_address
|
||||
|
||||
from sqlalchemy import types
|
||||
|
||||
from .scalar_coercible import ScalarCoercible
|
||||
|
||||
|
||||
class IPAddressType(ScalarCoercible, types.TypeDecorator):
    """
    Store IP addresses as strings.

    Address objects are rendered to their string form before being
    written and parsed back into :func:`ipaddress.ip_address` objects
    when read. Scalar coercion is supported, so assigning a plain string
    yields an address object.

    ::


        from sqlalchemy_utils import IPAddressType


        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, autoincrement=True)
            name = sa.Column(sa.Unicode(255))
            ip_address = sa.Column(IPAddressType)


        user = User()
        user.ip_address = '123.123.123.123'
        session.add(user)
        session.commit()

        user.ip_address  # IPAddress object
    """

    impl = types.Unicode(50)
    cache_ok = True

    def __init__(self, max_length=50, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Rebuild the implementation type with the requested length.
        self.impl = types.Unicode(max_length)

    def process_bind_param(self, value, dialect):
        """Serialize an address object to its string form (or None)."""
        if not value:
            return None
        return str(value)

    def process_result_value(self, value, dialect):
        """Parse a stored string back into an address object (or None)."""
        if not value:
            return None
        return ip_address(value)

    def _coerce(self, value):
        """Coerce assigned values into address objects (or None)."""
        if not value:
            return None
        return ip_address(value)

    @property
    def python_type(self):
        return self.impl.type.python_type
|
|
@ -0,0 +1,77 @@
|
|||
import json
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql.base import ischema_names
|
||||
|
||||
# Prefer SQLAlchemy's native postgresql JSON type; on very old SQLAlchemy
# versions fall back to a minimal user-defined 'json' column type.
try:
    from sqlalchemy.dialects.postgresql import JSON
    has_postgres_json = True
except ImportError:
    class PostgresJSONType(sa.types.UserDefinedType):
        """
        Fallback 'json' column type for postgresql when SQLAlchemy does
        not ship a native JSON type.
        """
        def get_col_spec(self):
            return 'json'

    # Register for reflection so 'json' columns map to this type.
    ischema_names['json'] = PostgresJSONType
    has_postgres_json = False
|
||||
|
||||
|
||||
class JSONType(sa.types.TypeDecorator):
    """
    JSONType offers way of saving JSON data structures to database. On
    PostgreSQL the underlying implementation of this data type is 'json' while
    on other databases its simply 'text'.

    ::


        from sqlalchemy_utils import JSONType


        class Product(Base):
            __tablename__ = 'product'
            id = sa.Column(sa.Integer, autoincrement=True)
            name = sa.Column(sa.Unicode(50))
            details = sa.Column(JSONType)


        product = Product()
        product.details = {
            'color': 'red',
            'type': 'car',
            'max-speed': '400 mph'
        }
        session.commit()
    """
    impl = sa.UnicodeText
    # JSON values are mutable containers, so instances are not hashable.
    hashable = False
    cache_ok = True

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def load_dialect_impl(self, dialect):
        """Pick 'json' on postgresql, plain unicode text elsewhere."""
        if dialect.name == 'postgresql':
            # Use the native JSON type.
            if has_postgres_json:
                return dialect.type_descriptor(JSON())
            else:
                return dialect.type_descriptor(PostgresJSONType())
        else:
            return dialect.type_descriptor(self.impl)

    def process_bind_param(self, value, dialect):
        # Native JSON columns serialize themselves; otherwise dump to text.
        if dialect.name == 'postgresql' and has_postgres_json:
            return value
        if value is not None:
            value = json.dumps(value)
        return value

    def process_result_value(self, value, dialect):
        # Postgresql already returns deserialized values.
        if dialect.name == 'postgresql':
            return value
        if value is not None:
            value = json.loads(value)
        return value
|
|
@ -0,0 +1,76 @@
|
|||
from sqlalchemy import types
|
||||
|
||||
from ..exceptions import ImproperlyConfigured
|
||||
from .scalar_coercible import ScalarCoercible
|
||||
|
||||
babel = None
|
||||
try:
|
||||
import babel
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
class LocaleType(ScalarCoercible, types.TypeDecorator):
    """
    LocaleType saves Babel_ Locale objects into database. The Locale objects
    are converted to string on the way in and back to object on the way out.

    In order to use LocaleType you need to install Babel_ first.

    .. _Babel: https://babel.pocoo.org/

    ::


        from sqlalchemy_utils import LocaleType
        from babel import Locale


        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, autoincrement=True)
            name = sa.Column(sa.Unicode(50))
            locale = sa.Column(LocaleType)


        user = User()
        user.locale = Locale('en_US')
        session.add(user)
        session.commit()


    Like many other types this type also supports scalar coercion:

    ::


        user.locale = 'de_DE'
        user.locale  # Locale('de', territory='DE')

    """

    impl = types.Unicode(10)

    cache_ok = True

    def __init__(self):
        # Babel is optional at module level; fail loudly at type creation.
        if babel is None:
            raise ImproperlyConfigured(
                'Babel packaged is required with LocaleType.'
            )

    def process_bind_param(self, value, dialect):
        """Serialize a Locale (or pre-formatted string) for storage."""
        if isinstance(value, babel.Locale):
            return str(value)

        if isinstance(value, str):
            return value

    def process_result_value(self, value, dialect):
        """Parse a stored locale string back into a Locale object."""
        if value is not None:
            return babel.Locale.parse(value)

    def _coerce(self, value):
        """Coerce assigned strings into Locale objects."""
        if value is not None and not isinstance(value, babel.Locale):
            return babel.Locale.parse(value)
        return value
|
|
@ -0,0 +1,121 @@
|
|||
from sqlalchemy import types
|
||||
from sqlalchemy.dialects.postgresql import ARRAY
|
||||
from sqlalchemy.dialects.postgresql.base import ischema_names, PGTypeCompiler
|
||||
from sqlalchemy.sql import expression
|
||||
|
||||
from ..primitives import Ltree
|
||||
from .scalar_coercible import ScalarCoercible
|
||||
|
||||
|
||||
class LtreeType(types.Concatenable, types.UserDefinedType, ScalarCoercible):
    """Postgresql LtreeType type.

    The LtreeType datatype can be used for representing labels of data stored
    in hierarchical tree-like structure. For more detailed information please
    refer to https://www.postgresql.org/docs/current/ltree.html

    ::

        from sqlalchemy_utils import LtreeType, Ltree


        class DocumentSection(Base):
            __tablename__ = 'document_section'
            id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
            path = sa.Column(LtreeType)


        section = DocumentSection(path=Ltree('Countries.Finland'))
        session.add(section)
        session.commit()

        section.path  # Ltree('Countries.Finland')


    .. note::
        Using :class:`LtreeType`, :class:`LQUERY` and :class:`LTXTQUERY` types
        may require installation of Postgresql ltree extension on the server
        side. Please visit https://www.postgresql.org/ for details.
    """
    cache_ok = True

    class comparator_factory(types.Concatenable.Comparator):
        # Exposes the ltree operators (@>, <@, ~/?, @) on column expressions.

        def ancestor_of(self, other):
            """``self @> other`` — is self an ancestor of other?"""
            if isinstance(other, list):
                return self.op('@>')(expression.cast(other, ARRAY(LtreeType)))
            else:
                return self.op('@>')(other)

        def descendant_of(self, other):
            """``self <@ other`` — is self a descendant of other?"""
            if isinstance(other, list):
                return self.op('<@')(expression.cast(other, ARRAY(LtreeType)))
            else:
                return self.op('<@')(other)

        def lquery(self, other):
            """Match against an lquery pattern (``?`` for lists, ``~`` else)."""
            if isinstance(other, list):
                return self.op('?')(expression.cast(other, ARRAY(LQUERY)))
            else:
                return self.op('~')(other)

        def ltxtquery(self, other):
            """Match against an ltxtquery full-text pattern (``@``)."""
            return self.op('@')(other)

    def bind_processor(self, dialect):
        # Send the raw dotted path of the Ltree value to the database.
        def process(value):
            if value:
                return value.path
        return process

    def result_processor(self, dialect, coltype):
        # Wrap returned paths back into Ltree objects.
        def process(value):
            return self._coerce(value)
        return process

    def literal_processor(self, dialect):
        # Render as a single-quoted SQL literal, escaping quotes.
        def process(value):
            value = value.replace("'", "''")
            return "'%s'" % value
        return process

    __visit_name__ = 'LTREE'

    def _coerce(self, value):
        if value:
            return Ltree(value)
|
||||
|
||||
|
||||
class LQUERY(types.TypeEngine):
    """Postgresql LQUERY type.
    See :class:`LTREE` for details.
    """
    __visit_name__ = 'LQUERY'
|
||||
|
||||
|
||||
class LTXTQUERY(types.TypeEngine):
    """Postgresql LTXTQUERY type.
    See :class:`LTREE` for details.
    """
    __visit_name__ = 'LTXTQUERY'
|
||||
|
||||
|
||||
# Register with SQLAlchemy's postgresql reflection so columns of these
# database types are mapped back to the classes above.
ischema_names['ltree'] = LtreeType
ischema_names['lquery'] = LQUERY
ischema_names['ltxtquery'] = LTXTQUERY
|
||||
|
||||
|
||||
def visit_LTREE(self, type_, **kw):
    """Render the LTREE column type in DDL."""
    return 'LTREE'


def visit_LQUERY(self, type_, **kw):
    """Render the LQUERY column type in DDL."""
    return 'LQUERY'


def visit_LTXTQUERY(self, type_, **kw):
    """Render the LTXTQUERY column type in DDL."""
    return 'LTXTQUERY'


# Teach the postgresql type compiler how to render the custom
# __visit_name__ values declared above.
PGTypeCompiler.visit_LTREE = visit_LTREE
PGTypeCompiler.visit_LQUERY = visit_LQUERY
PGTypeCompiler.visit_LTXTQUERY = visit_LTXTQUERY
|
|
@ -0,0 +1,259 @@
|
|||
import weakref
|
||||
|
||||
from sqlalchemy import types
|
||||
from sqlalchemy.dialects import oracle, postgresql, sqlite
|
||||
from sqlalchemy.ext.mutable import Mutable
|
||||
|
||||
from ..exceptions import ImproperlyConfigured
|
||||
from .scalar_coercible import ScalarCoercible
|
||||
|
||||
passlib = None
|
||||
try:
|
||||
import passlib
|
||||
from passlib.context import LazyCryptContext
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
class Password(Mutable):
    """Mutable value object holding either a password hash or a plaintext.

    ``hash`` carries the (bytes) hash as read from the database;
    ``secret`` carries a not-yet-hashed plaintext assigned by application
    code. Exactly one of the two is set at a time. Comparing against a
    plaintext string verifies it with the associated passlib context and
    transparently re-hashes deprecated hashes.
    """

    @classmethod
    def coerce(cls, key, value):
        """Coerce attribute assignments into Password instances.

        ``str``/``bytes`` values are treated as secrets to be hashed.
        Anything else is delegated to ``Mutable.coerce`` (which returns
        ``None`` for ``None`` and raises ``ValueError`` otherwise).
        """
        if isinstance(value, Password):
            return value

        if isinstance(value, (str, bytes)):
            return cls(value, secret=True)

        # Fix: the original discarded Mutable.coerce's result by falling
        # off the end; propagate it explicitly (None for None input).
        return super().coerce(key, value)

    def __init__(self, value, context=None, secret=False):
        # Store the hash (if it is one).
        self.hash = value if not secret else None

        # Store the secret if we have one.
        self.secret = value if secret else None

        # The hash should be bytes.
        if isinstance(self.hash, str):
            self.hash = self.hash.encode('utf8')

        # Save weakref of the password context (if we have one) to avoid
        # keeping the owning type alive through every value.
        self.context = weakref.proxy(context) if context is not None else None

    def __eq__(self, value):
        """Verify a plaintext against the hash, or compare two hashes.

        Comparing against a plaintext uses the passlib context and, when
        the stored hash used a deprecated scheme, replaces it with a
        fresh hash and flags the attribute as changed.
        """
        if self.hash is None or value is None:
            # Ensure that we don't continue comparison if one of us is None.
            return self.hash is value

        if isinstance(value, Password):
            # Comparing 2 hashes isn't very useful; but this equality
            # method breaks otherwise.
            return value.hash == self.hash

        if self.context is None:
            # Compare 2 hashes again as we don't know how to validate.
            return value == self

        if isinstance(value, (str, bytes)):
            valid, new = self.context.verify_and_update(value, self.hash)
            if valid and new:
                # New hash was calculated due to various reasons; stored one
                # wasn't optimal, etc.
                self.hash = new

                # The hash should be bytes.
                if isinstance(self.hash, str):
                    self.hash = self.hash.encode('utf8')
                self.changed()

            return valid

        return False

    def __ne__(self, value):
        return not (self == value)
|
||||
|
||||
|
||||
class PasswordType(ScalarCoercible, types.TypeDecorator):
|
||||
"""
|
||||
PasswordType hashes passwords as they come into the database and allows
|
||||
verifying them using a Pythonic interface. This Pythonic interface
|
||||
relies on setting up automatic data type coercion using the
|
||||
:func:`~sqlalchemy_utils.listeners.force_auto_coercion` function.
|
||||
|
||||
All keyword arguments (aside from max_length) are forwarded to the
|
||||
construction of a `passlib.context.LazyCryptContext` object, which
|
||||
also supports deferred configuration via the `onload` callback.
|
||||
|
||||
The following usage will create a password column that will
|
||||
automatically hash new passwords as `pbkdf2_sha512` but still compare
|
||||
passwords against pre-existing `md5_crypt` hashes. As passwords are
|
||||
compared; the password hash in the database will be updated to
|
||||
be `pbkdf2_sha512`.
|
||||
|
||||
::
|
||||
|
||||
|
||||
class Model(Base):
|
||||
password = sa.Column(PasswordType(
|
||||
schemes=[
|
||||
'pbkdf2_sha512',
|
||||
'md5_crypt'
|
||||
],
|
||||
|
||||
deprecated=['md5_crypt']
|
||||
))
|
||||
|
||||
|
||||
Verifying password is as easy as:
|
||||
|
||||
::
|
||||
|
||||
target = Model()
|
||||
target.password = 'b'
|
||||
# '$5$rounds=80000$H.............'
|
||||
|
||||
target.password == 'b'
|
||||
# True
|
||||
|
||||
|
||||
Lazy configuration of the type with Flask config:
|
||||
|
||||
::
|
||||
|
||||
|
||||
import flask
|
||||
from sqlalchemy_utils import PasswordType, force_auto_coercion
|
||||
|
||||
force_auto_coercion()
|
||||
|
||||
class User(db.Model):
|
||||
__tablename__ = 'user'
|
||||
|
||||
password = db.Column(
|
||||
PasswordType(
|
||||
# The returned dictionary is forwarded to the CryptContext
|
||||
onload=lambda **kwargs: dict(
|
||||
schemes=flask.current_app.config['PASSWORD_SCHEMES'],
|
||||
**kwargs
|
||||
),
|
||||
),
|
||||
unique=False,
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
"""
|
||||
|
||||
impl = types.VARBINARY(1024)
|
||||
python_type = Password
|
||||
cache_ok = True
|
||||
|
||||
def __init__(self, max_length=None, **kwargs):
|
||||
# Fail if passlib is not found.
|
||||
if passlib is None:
|
||||
raise ImproperlyConfigured(
|
||||
"'passlib' is required to use 'PasswordType'"
|
||||
)
|
||||
|
||||
# Construct the passlib crypt context.
|
||||
self.context = LazyCryptContext(**kwargs)
|
||||
self._max_length = max_length
|
||||
|
||||
@property
|
||||
def hashing_method(self):
|
||||
return (
|
||||
'hash'
|
||||
if hasattr(self.context, 'hash')
|
||||
else 'encrypt'
|
||||
)
|
||||
|
||||
@property
|
||||
def length(self):
|
||||
"""Get column length."""
|
||||
if self._max_length is None:
|
||||
self._max_length = self.calculate_max_length()
|
||||
|
||||
return self._max_length
|
||||
|
||||
def calculate_max_length(self):
|
||||
# Calculate the largest possible encoded password.
|
||||
# name + rounds + salt + hash + ($ * 4) of largest hash
|
||||
max_lengths = [1024]
|
||||
for name in self.context.schemes():
|
||||
scheme = getattr(__import__('passlib.hash').hash, name)
|
||||
length = 4 + len(scheme.name)
|
||||
length += len(str(getattr(scheme, 'max_rounds', '')))
|
||||
length += (getattr(scheme, 'max_salt_size', 0) or 0)
|
||||
length += getattr(
|
||||
scheme,
|
||||
'encoded_checksum_size',
|
||||
scheme.checksum_size
|
||||
)
|
||||
max_lengths.append(length)
|
||||
|
||||
# Return the maximum calculated max length.
|
||||
return max(max_lengths)
|
||||
|
||||
def load_dialect_impl(self, dialect):
|
||||
if dialect.name == 'postgresql':
|
||||
# Use a BYTEA type for postgresql.
|
||||
impl = postgresql.BYTEA(self.length)
|
||||
elif dialect.name == 'oracle':
|
||||
# Use a RAW type for oracle.
|
||||
impl = oracle.RAW(self.length)
|
||||
elif dialect.name == 'sqlite':
|
||||
# Use a BLOB type for sqlite
|
||||
impl = sqlite.BLOB(self.length)
|
||||
else:
|
||||
# Use a VARBINARY for all other dialects.
|
||||
impl = types.VARBINARY(self.length)
|
||||
return dialect.type_descriptor(impl)
|
||||
|
||||
def process_bind_param(self, value, dialect):
    """Hash plain-text input on the way into the database.

    ``Password`` objects carrying a fresh secret are re-hashed; already
    hashed ones are stored as-is. Plain strings are hashed directly.
    Anything else (including ``None``) falls through and is stored as NULL.
    """
    if isinstance(value, Password):
        if value.secret is not None:
            # A plain-text secret is attached: hash it now.
            return self._hash(value.secret).encode('utf8')
        # Value has already been hashed.
        return value.hash

    if isinstance(value, str):
        # Assume a bare string has not been hashed yet.
        return self._hash(value).encode('utf8')
|
||||
|
||||
def process_result_value(self, value, dialect):
    """Wrap non-NULL database values in a ``Password`` bound to this context."""
    return Password(value, self.context) if value is not None else None
|
||||
|
||||
def _hash(self, value):
    """Hash *value* with the context's current entry point (hash/encrypt)."""
    hasher = getattr(self.context, self.hashing_method)
    return hasher(value)
|
||||
|
||||
def _coerce(self, value):
    """Coerce assigned values into ``Password`` objects bound to this type.

    Plain secrets are hashed immediately; existing ``Password`` objects get
    their context rebound and any attached secret re-hashed and discarded.
    """
    if value is None:
        return None

    if not isinstance(value, Password):
        # Plain secret: hash it and wrap the digest.
        digest = self._hash(value).encode('utf8')
        return Password(digest, context=self.context)

    # Existing Password: rebind its context via a weak proxy to avoid
    # keeping this type alive through the value.
    value.context = weakref.proxy(self.context)

    if value.secret is not None:
        # A new secret was attached; re-hash it and drop the plain text.
        value.hash = self._hash(value.secret).encode('utf8')
        value.secret = None

    return value
|
||||
|
||||
@property
def python_type(self):
    """Delegate to the Python type of the underlying implementation type."""
    return self.impl.type.python_type
|
||||
|
||||
|
||||
# Hook PasswordType columns up to Password's coercion machinery —
# presumably registers an attribute listener; see Password.associate_with.
Password.associate_with(PasswordType)
|
|
@ -0,0 +1,390 @@
|
|||
"""
|
||||
CompositeType provides means to interact with
|
||||
`PostgreSQL composite types`_. Currently this type features:
|
||||
|
||||
* Easy attribute access to composite type fields
|
||||
* Supports SQLAlchemy TypeDecorator types
|
||||
* Ability to include composite types as part of PostgreSQL arrays
|
||||
* Type creation and dropping
|
||||
|
||||
Installation
|
||||
^^^^^^^^^^^^
|
||||
|
||||
CompositeType automatically attaches `before_create` and `after_drop` DDL
|
||||
listeners. These listeners create and drop the composite type in the
|
||||
database. This means it works out of the box in your test environment where
|
||||
you create the tables on each test run.
|
||||
|
||||
When you already have your database set up you should call
|
||||
:func:`register_composites` after you've set up all models.
|
||||
|
||||
::
|
||||
|
||||
register_composites(conn)
|
||||
|
||||
|
||||
|
||||
Usage
|
||||
^^^^^
|
||||
|
||||
::
|
||||
|
||||
from collections import OrderedDict
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy_utils import CompositeType, CurrencyType
|
||||
|
||||
|
||||
class Account(Base):
|
||||
__tablename__ = 'account'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
balance = sa.Column(
|
||||
CompositeType(
|
||||
'money_type',
|
||||
[
|
||||
sa.Column('currency', CurrencyType),
|
||||
sa.Column('amount', sa.Integer)
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
Creation
|
||||
~~~~~~~~
|
||||
When creating CompositeType, you can either pass in a tuple or a dictionary.
|
||||
|
||||
::
|
||||
account1 = Account()
|
||||
account1.balance = ('USD', 15)
|
||||
|
||||
account2 = Account()
|
||||
account2.balance = {'currency': 'USD', 'amount': 15}
|
||||
|
||||
session.add(account1)
|
||||
session.add(account2)
|
||||
session.commit()
|
||||
|
||||
|
||||
Accessing fields
|
||||
^^^^^^^^^^^^^^^^
|
||||
|
||||
CompositeType provides attribute access to underlying fields. In the following
|
||||
example we find all accounts with balance amount more than 5000.
|
||||
|
||||
|
||||
::
|
||||
|
||||
session.query(Account).filter(Account.balance.amount > 5000)
|
||||
|
||||
|
||||
Arrays of composites
|
||||
^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
::
|
||||
|
||||
from sqlalchemy.dialects.postgresql import ARRAY
|
||||
|
||||
|
||||
class Account(Base):
|
||||
__tablename__ = 'account'
|
||||
id = sa.Column(sa.Integer, primary_key=True)
|
||||
balances = sa.Column(
|
||||
ARRAY(
|
||||
CompositeType(
|
||||
'money_type',
|
||||
[
|
||||
sa.Column('currency', CurrencyType),
|
||||
sa.Column('amount', sa.Integer)
|
||||
]
|
||||
),
|
||||
dimensions=1
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
.. _PostgreSQL composite types:
|
||||
https://www.postgresql.org/docs/current/rowtypes.html
|
||||
|
||||
|
||||
Related links:
|
||||
|
||||
https://schinckel.net/2014/09/24/using-postgres-composite-types-in-django/
|
||||
"""
|
||||
from collections import namedtuple
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql.psycopg2 import PGDialect_psycopg2
|
||||
from sqlalchemy.ext.compiler import compiles
|
||||
from sqlalchemy.schema import _CreateDropBase
|
||||
from sqlalchemy.sql.expression import FunctionElement
|
||||
from sqlalchemy.types import (
|
||||
SchemaType,
|
||||
to_instance,
|
||||
TypeDecorator,
|
||||
UserDefinedType
|
||||
)
|
||||
|
||||
from .. import ImproperlyConfigured
|
||||
|
||||
psycopg2 = None
|
||||
CompositeCaster = None
|
||||
adapt = None
|
||||
AsIs = None
|
||||
register_adapter = None
|
||||
try:
|
||||
import psycopg2
|
||||
from psycopg2.extensions import adapt, AsIs, register_adapter
|
||||
from psycopg2.extras import CompositeCaster
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
class CompositeElement(FunctionElement):
    """
    Instances of this class wrap a Postgres composite type.

    Represents access to a single *field* of a composite-typed expression;
    rendered as ``(base).field`` by the compiler hook below.
    """
    def __init__(self, base, field, type_):
        # Name of the composite field being addressed, and its SQL type.
        self.name = field
        self.type = to_instance(type_)

        super().__init__(base)
|
||||
|
||||
|
||||
@compiles(CompositeElement)
def _compile_pgelem(expr, compiler, **kw):
    """Render composite field access as ``(base).field``."""
    base = compiler.process(expr.clauses, **kw)
    return f'({base}).{expr.name}'
|
||||
|
||||
|
||||
# TODO: Make the registration work on connection level instead of global level
# Global registry mapping composite type name -> CompositeType instance.
registered_composites = {}
|
||||
|
||||
|
||||
class CompositeType(UserDefinedType, SchemaType):
    """
    Represents a PostgreSQL composite type.

    :param name:
        Name of the composite type.
    :param columns:
        List of columns that this composite type consists of
    """
    # Composite values round-trip as (named) tuples on the Python side.
    python_type = tuple

    class comparator_factory(UserDefinedType.Comparator):
        # Enables `Model.col.field` attribute access in SQL expressions.
        def __getattr__(self, key):
            try:
                type_ = self.type.typemap[key]
            except KeyError:
                raise KeyError(
                    "Type '{}' doesn't have an attribute: '{}'".format(
                        self.name, key
                    )
                )

            return CompositeElement(self.expr, key, type_)

    def __init__(self, name, columns, quote=None, **kwargs):
        # psycopg2 is needed for composite adaptation and result casting.
        if psycopg2 is None:
            raise ImproperlyConfigured(
                "'psycopg2' package is required in order to use CompositeType."
            )
        SchemaType.__init__(
            self,
            name=name,
            quote=quote
        )
        self.columns = columns
        # Reuse the namedtuple class when this type name was seen before so
        # values from different instances adapt/compare consistently.
        if name in registered_composites:
            self.type_cls = registered_composites[name].type_cls
        else:
            self.type_cls = namedtuple(
                self.name, [c.name for c in columns]
            )
        registered_composites[name] = self

        # psycopg2 caster that turns DB rows of this type into namedtuples.
        # NOTE: `make` deliberately uses `obj` as the instance slot; `self`
        # here is the enclosing CompositeType (closure).
        class Caster(CompositeCaster):
            def make(obj, values):
                return self.type_cls(*values)

        self.caster = Caster
        attach_composite_listeners()

    def get_col_spec(self):
        # Column DDL for this type is simply its name.
        return self.name

    def bind_processor(self, dialect):
        # Accepts dicts (keyed by column name) or sequences (positional)
        # and produces the registered namedtuple for psycopg2 adaptation.
        def process(value):
            if value is None:
                return None

            processed_value = []
            for i, column in enumerate(self.columns):
                current_value = (
                    value.get(column.name)
                    if isinstance(value, dict)
                    else value[i]
                )

                if isinstance(column.type, TypeDecorator):
                    # Let wrapped types serialize their own fields.
                    processed_value.append(
                        column.type.process_bind_param(
                            current_value, dialect
                        )
                    )
                else:
                    processed_value.append(current_value)
            return self.type_cls(*processed_value)
        return process

    def result_processor(self, dialect, coltype):
        # Post-process fields of the casted value through any TypeDecorator
        # columns, rebuilding the same namedtuple class.
        def process(value):
            if value is None:
                return None
            cls = value.__class__
            kwargs = {}
            for column in self.columns:
                if isinstance(column.type, TypeDecorator):
                    kwargs[column.name] = column.type.process_result_value(
                        getattr(value, column.name), dialect
                    )
                else:
                    kwargs[column.name] = getattr(value, column.name)
            return cls(**kwargs)
        return process

    def create(self, bind=None, checkfirst=None):
        # Emit CREATE TYPE unless checkfirst is set and the type exists.
        if (
            not checkfirst or
            not bind.dialect.has_type(bind, self.name, schema=self.schema)
        ):
            bind.execute(CreateCompositeType(self))

    def drop(self, bind=None, checkfirst=True):
        # Emit DROP TYPE only when the type actually exists (if checkfirst).
        if (
            checkfirst and
            bind.dialect.has_type(bind, self.name, schema=self.schema)
        ):
            bind.execute(DropCompositeType(self))
|
||||
|
||||
|
||||
def register_psycopg2_composite(dbapi_connection, composite):
    """Wire psycopg2 result casting and parameter adaptation for *composite*.

    Registers the composite's caster globally (every cursor) and installs a
    psycopg2 adapter turning its namedtuple values into
    ``(v1, v2, ...)::type_name`` SQL literals.
    """
    psycopg2.extras.register_composite(
        composite.name,
        dbapi_connection,
        globally=True,
        factory=composite.caster
    )

    def adapt_composite(value):
        # A throwaway dialect instance, used only for TypeDecorator
        # bind-processing and identifier quoting.
        dialect = PGDialect_psycopg2()
        adapted = [
            adapt(
                getattr(value, column.name)
                if not isinstance(column.type, TypeDecorator)
                else column.type.process_bind_param(
                    getattr(value, column.name),
                    dialect
                )
            )
            for column in
            composite.columns
        ]
        # NOTE(review): this loop rebinds the `value` parameter; harmless
        # (the parameter is no longer needed) but confusing to read.
        for value in adapted:
            if hasattr(value, 'prepare'):
                # Some adapters (e.g. QuotedString) need the connection in
                # order to quote correctly.
                value.prepare(dbapi_connection)
        values = [
            value.getquoted().decode(dbapi_connection.encoding)
            for value in adapted
        ]
        return AsIs(
            '({})::{}'.format(
                ', '.join(values),
                dialect.identifier_preparer.quote(composite.name)
            )
        )

    register_adapter(composite.type_cls, adapt_composite)
|
||||
|
||||
|
||||
def get_driver_connection(connection):
    """Return the raw DBAPI connection behind a SQLAlchemy connection.

    SQLAlchemy 2.0 exposes it as ``driver_connection``; older releases nest
    it one level deeper as ``connection.connection``.
    """
    try:
        # SQLAlchemy 2.0 naming.
        return connection.connection.driver_connection
    except AttributeError:
        # SQLAlchemy 1.x fallback.
        return connection.connection.connection
|
||||
|
||||
|
||||
def before_create(target, connection, **kw):
    """MetaData ``before_create`` hook.

    Creates every registered composite type (if missing) and registers its
    psycopg2 caster/adapter on the current driver connection.
    """
    for composite in registered_composites.values():
        composite.create(connection, checkfirst=True)
        register_psycopg2_composite(
            get_driver_connection(connection), composite
        )
|
||||
|
||||
|
||||
def after_drop(target, connection, **kw):
    """MetaData ``after_drop`` hook: drop every registered composite type."""
    for composite in registered_composites.values():
        composite.drop(connection, checkfirst=True)
|
||||
|
||||
|
||||
def register_composites(connection):
    """Register psycopg2 casting for all known composites on *connection*.

    Call this after model setup when the database already exists (i.e. the
    ``before_create`` DDL hook will not fire).
    """
    for composite in registered_composites.values():
        register_psycopg2_composite(
            get_driver_connection(connection), composite
        )
|
||||
|
||||
|
||||
def attach_composite_listeners():
    """Idempotently attach the MetaData create/drop hooks for composites."""
    for args in (
        (sa.MetaData, 'before_create', before_create),
        (sa.MetaData, 'after_drop', after_drop),
    ):
        # Guard so repeated CompositeType construction never double-attaches.
        if not sa.event.contains(*args):
            sa.event.listen(*args)
|
||||
|
||||
|
||||
def remove_composite_listeners():
    """Detach the MetaData create/drop hooks, if currently attached."""
    for args in (
        (sa.MetaData, 'before_create', before_create),
        (sa.MetaData, 'after_drop', after_drop),
    ):
        if sa.event.contains(*args):
            sa.event.remove(*args)
|
||||
|
||||
|
||||
class CreateCompositeType(_CreateDropBase):
    """DDL element rendered as ``CREATE TYPE name AS (...)`` by the compiler hook below."""
    pass
|
||||
|
||||
|
||||
@compiles(CreateCompositeType)
def _visit_create_composite_type(create, compiler, **kw):
    """Compile a CreateCompositeType element into CREATE TYPE SQL."""
    type_ = create.element
    process_type = compiler.dialect.type_compiler.process
    field_list = ', '.join(
        f'{column.name} {process_type(to_instance(column.type))}'
        for column in type_.columns
    )
    return 'CREATE TYPE {} AS ({})'.format(
        compiler.preparer.format_type(type_), field_list
    )
|
||||
|
||||
|
||||
class DropCompositeType(_CreateDropBase):
    """DDL element rendered as ``DROP TYPE name`` by the compiler hook below."""
    pass
|
||||
|
||||
|
||||
@compiles(DropCompositeType)
def _visit_drop_composite_type(drop, compiler, **kw):
    """Compile a DropCompositeType element into DROP TYPE SQL."""
    return 'DROP TYPE {}'.format(compiler.preparer.format_type(drop.element))
|
|
@ -0,0 +1,204 @@
|
|||
"""
|
||||
.. note::
|
||||
|
||||
The `phonenumbers`_ package must be installed to use PhoneNumber types.
|
||||
|
||||
.. _phonenumbers: https://github.com/daviddrysdale/python-phonenumbers
|
||||
"""
|
||||
|
||||
from sqlalchemy import exc, types
|
||||
|
||||
from ..exceptions import ImproperlyConfigured
|
||||
from ..utils import str_coercible
|
||||
from .scalar_coercible import ScalarCoercible
|
||||
|
||||
try:
|
||||
import phonenumbers
|
||||
from phonenumbers.phonenumber import PhoneNumber as BasePhoneNumber
|
||||
from phonenumbers.phonenumberutil import NumberParseException
|
||||
except ImportError:
|
||||
phonenumbers = None
|
||||
BasePhoneNumber = object
|
||||
NumberParseException = Exception
|
||||
|
||||
|
||||
class PhoneNumberParseException(NumberParseException, exc.DontWrapMixin):
    """
    Wraps exceptions from phonenumbers with SQLAlchemy's DontWrapMixin
    so we get more meaningful exceptions on validation failure instead of a
    generic StatementError.

    Clients can catch this as either a PhoneNumberParseException or
    NumberParseException from the phonenumbers library.
    """

    pass
|
||||
|
||||
|
||||
@str_coercible
class PhoneNumber(BasePhoneNumber):
    """
    Extends a PhoneNumber class from `Python phonenumbers library`_. Adds
    different phone number formats to attributes, so they can be easily used
    in templates. Phone number validation method is also implemented.

    Takes the raw phone number and country code as params and parses them
    into a PhoneNumber object.

    .. _Python phonenumbers library:
       https://github.com/daviddrysdale/python-phonenumbers


    ::

        from sqlalchemy_utils import PhoneNumber


        class User(self.Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
            name = sa.Column(sa.Unicode(255))
            _phone_number = sa.Column(sa.Unicode(20))
            country_code = sa.Column(sa.Unicode(8))

            phone_number = sa.orm.composite(
                PhoneNumber,
                _phone_number,
                country_code
            )


        user = User(phone_number=PhoneNumber('0401234567', 'FI'))

        user.phone_number.e164  # '+358401234567'
        user.phone_number.international  # '+358 40 1234567'
        user.phone_number.national  # '040 1234567'
        user.country_code  # 'FI'


    :param raw_number:
        String representation of the phone number.
    :param region:
        Region of the phone number.
    :param check_region:
        Whether to check the supplied region parameter;
        should always be True for external callers.
        Can be useful for short codes or toll free
    """

    def __init__(self, raw_number, region=None, check_region=True):
        # Bail if phonenumbers is not found.
        if phonenumbers is None:
            raise ImproperlyConfigured(
                "The 'phonenumbers' package is required to use 'PhoneNumber'"
            )

        try:
            self._phone_number = phonenumbers.parse(
                raw_number, region, _check_region=check_region
            )
        except NumberParseException as e:
            # Wrap exception so SQLAlchemy doesn't swallow it as a
            # StatementError
            #
            # Worth noting that if -1 shows up as the error_type
            # it's likely because the API has changed upstream and these
            # bindings need to be updated.
            raise PhoneNumberParseException(getattr(e, "error_type", -1), str(e))

        # Copy the parsed fields onto this instance (it *is* a
        # BasePhoneNumber) so equality with plain phonenumbers objects works.
        super().__init__(
            country_code=self._phone_number.country_code,
            national_number=self._phone_number.national_number,
            extension=self._phone_number.extension,
            italian_leading_zero=self._phone_number.italian_leading_zero,
            raw_input=self._phone_number.raw_input,
            country_code_source=self._phone_number.country_code_source,
            preferred_domestic_carrier_code=(
                self._phone_number.preferred_domestic_carrier_code
            ),
        )
        self.region = region
        # Pre-render the common string formats for easy template access.
        self.national = phonenumbers.format_number(
            self._phone_number, phonenumbers.PhoneNumberFormat.NATIONAL
        )
        self.international = phonenumbers.format_number(
            self._phone_number, phonenumbers.PhoneNumberFormat.INTERNATIONAL
        )
        self.e164 = phonenumbers.format_number(
            self._phone_number, phonenumbers.PhoneNumberFormat.E164
        )

    def __composite_values__(self):
        # Used by sa.orm.composite: persists the national format + region.
        return self.national, self.region

    def is_valid_number(self):
        # True when the parsed number is valid for its region.
        return phonenumbers.is_valid_number(self._phone_number)

    def __unicode__(self):
        # NOTE(review): presumably consumed by the @str_coercible decorator
        # to derive __str__ — confirm in ..utils.
        return self.national

    def __hash__(self):
        # Hash on the canonical E.164 form so equivalent numbers collide.
        return hash(self.e164)
|
||||
|
||||
|
||||
class PhoneNumberType(ScalarCoercible, types.TypeDecorator):
    """
    Changes PhoneNumber objects to a string representation on the way in and
    changes them back to PhoneNumber objects on the way out. If E164 is used
    as storing format, no country code is needed for parsing the database
    value to PhoneNumber object.

    ::

        class User(self.Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, autoincrement=True, primary_key=True)
            name = sa.Column(sa.Unicode(255))
            phone_number = sa.Column(PhoneNumberType())


        user = User(phone_number='+358401234567')

        user.phone_number.e164  # '+358401234567'
        user.phone_number.international  # '+358 40 1234567'
        user.phone_number.national  # '040 1234567'
    """

    # Which PhoneNumber attribute is persisted (e.g. 'e164', 'national').
    STORE_FORMAT = "e164"
    impl = types.Unicode(20)
    python_type = PhoneNumber
    # Safe to use as part of SQLAlchemy's statement cache key.
    cache_ok = True

    def __init__(self, region="US", max_length=20, *args, **kwargs):
        # Bail if phonenumbers is not found.
        if phonenumbers is None:
            raise ImproperlyConfigured(
                "The 'phonenumbers' package is required to use 'PhoneNumberType'"
            )

        super().__init__(*args, **kwargs)
        # Default region used when parsing numbers without a country prefix.
        self.region = region
        # Override the class-level impl so the column width is configurable.
        self.impl = types.Unicode(max_length)

    def process_bind_param(self, value, dialect):
        # Serialize to the configured string format on the way in.
        if value:
            if not isinstance(value, PhoneNumber):
                value = PhoneNumber(value, region=self.region)

            if self.STORE_FORMAT == "e164" and value.extension:
                # E.164 has no extension notation; keep it with an
                # ';ext=' suffix so it survives the round trip.
                return f"{value.e164};ext={value.extension}"

            return getattr(value, self.STORE_FORMAT)

        return value

    def process_result_value(self, value, dialect):
        # Re-parse the stored string into a PhoneNumber on the way out.
        if value:
            return PhoneNumber(value, self.region)
        return value

    def _coerce(self, value):
        # Coerce plain strings assigned to the attribute into PhoneNumbers;
        # falsy values normalize to None.
        if value and not isinstance(value, PhoneNumber):
            value = PhoneNumber(value, region=self.region)

        return value or None
|
|
@ -0,0 +1,480 @@
|
|||
"""
|
||||
SQLAlchemy-Utils provides wide variety of range data types. All range data
|
||||
types return Interval objects of intervals_ package. In order to use range data
|
||||
types you need to install intervals_ with:
|
||||
|
||||
::
|
||||
|
||||
pip install intervals
|
||||
|
||||
|
||||
Intervals package provides good chunk of additional interval operators that for
|
||||
example psycopg2 range objects do not support.
|
||||
|
||||
|
||||
|
||||
Some good reading for practical interval implementations:
|
||||
|
||||
https://wiki.postgresql.org/images/f/f0/Range-types.pdf
|
||||
|
||||
|
||||
Range type initialization
|
||||
-------------------------
|
||||
|
||||
::
|
||||
|
||||
|
||||
|
||||
from sqlalchemy_utils import IntRangeType
|
||||
|
||||
|
||||
class Event(Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, autoincrement=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
estimated_number_of_persons = sa.Column(IntRangeType)
|
||||
|
||||
|
||||
|
||||
You can also set a step parameter for range type. The values that are not
|
||||
multipliers of given step will be rounded up to nearest step multiplier.
|
||||
|
||||
|
||||
::
|
||||
|
||||
|
||||
from sqlalchemy_utils import IntRangeType
|
||||
|
||||
|
||||
class Event(Base):
|
||||
__tablename__ = 'user'
|
||||
id = sa.Column(sa.Integer, autoincrement=True)
|
||||
name = sa.Column(sa.Unicode(255))
|
||||
estimated_number_of_persons = sa.Column(IntRangeType(step=1000))
|
||||
|
||||
|
||||
event = Event(estimated_number_of_persons=[100, 1200])
|
||||
event.estimated_number_of_persons.lower # 0
|
||||
event.estimated_number_of_persons.upper # 1000
|
||||
|
||||
|
||||
Range type operators
|
||||
--------------------
|
||||
|
||||
SQLAlchemy-Utils supports many range type operators. These operators follow the
|
||||
`intervals` package interval coercion rules.
|
||||
|
||||
So for example when we make a query such as:
|
||||
|
||||
::
|
||||
|
||||
session.query(Car).filter(Car.price_range == 300)
|
||||
|
||||
|
||||
It is essentially the same as:
|
||||
|
||||
::
|
||||
|
||||
session.query(Car).filter(Car.price_range == DecimalInterval([300, 300]))
|
||||
|
||||
|
||||
Comparison operators
|
||||
^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
All range types support all comparison operators (>, >=, ==, !=, <=, <).
|
||||
|
||||
::
|
||||
|
||||
Car.price_range < [12, 300]
|
||||
|
||||
Car.price_range == [12, 300]
|
||||
|
||||
Car.price_range < 300
|
||||
|
||||
Car.price_range > (300, 500)
|
||||
|
||||
# Whether or not range is strictly left of another range
|
||||
Car.price_range << [300, 500]
|
||||
|
||||
# Whether or not range is strictly right of another range
|
||||
Car.price_range >> [300, 500]
|
||||
|
||||
|
||||
|
||||
Membership operators
|
||||
^^^^^^^^^^^^^^^^^^^^
|
||||
|
||||
::
|
||||
|
||||
Car.price_range.contains([300, 500])
|
||||
|
||||
Car.price_range.contained_by([300, 500])
|
||||
|
||||
Car.price_range.in_([[300, 500], [800, 900]])
|
||||
|
||||
~ Car.price_range.in_([[300, 400], [700, 800]])
|
||||
|
||||
|
||||
Length
|
||||
^^^^^^
|
||||
|
||||
SQLAlchemy-Utils provides length property for all range types. The
|
||||
implementation of this property varies on different range types.
|
||||
|
||||
In the following example we find all cars whose price range's length is more
|
||||
than 500.
|
||||
|
||||
::
|
||||
|
||||
session.query(Car).filter(
|
||||
Car.price_range.length > 500
|
||||
)
|
||||
|
||||
|
||||
|
||||
.. _intervals: https://github.com/kvesteri/intervals
|
||||
"""
|
||||
from collections.abc import Iterable
|
||||
from datetime import timedelta
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import types
|
||||
from sqlalchemy.dialects.postgresql import (
|
||||
DATERANGE,
|
||||
INT4RANGE,
|
||||
INT8RANGE,
|
||||
NUMRANGE,
|
||||
TSRANGE
|
||||
)
|
||||
|
||||
from ..exceptions import ImproperlyConfigured
|
||||
from .scalar_coercible import ScalarCoercible
|
||||
|
||||
intervals = None
|
||||
try:
|
||||
import intervals
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
class RangeComparator(types.TypeEngine.Comparator):
    """Comparator adding interval coercion and PostgreSQL range operators.

    Right-hand operands given as raw scalars, tuples, lists or strings are
    coerced into the column type's ``interval_class`` before comparison.
    """

    @classmethod
    def coerced_func(cls, func):
        # Build a wrapper around the named base-class operator that coerces
        # the right-hand operand first (wired onto __eq__, __lt__, ... by
        # the module-level loop below this class).
        def operation(self, other, **kwargs):
            other = self.coerce_arg(other)
            return getattr(types.TypeEngine.Comparator, func)(
                self, other, **kwargs
            )
        return operation

    def coerce_arg(self, other):
        # Only coerce representations the intervals package understands;
        # anything else (e.g. a column expression) passes through untouched.
        coerced_types = (
            self.type.interval_class.type,
            tuple,
            list,
            str,
        )

        if isinstance(other, coerced_types):
            return self.type.interval_class(other)
        return other

    def in_(self, other):
        # Coerce each member of the candidate collection; a bare string is
        # a single interval representation, not a collection.
        if (
            isinstance(other, Iterable) and
            not isinstance(other, str)
        ):
            other = map(self.coerce_arg, other)
        return super().in_(other)

    def notin_(self, other):
        if (
            isinstance(other, Iterable) and
            not isinstance(other, str)
        ):
            other = map(self.coerce_arg, other)
        return super().notin_(other)

    def __rshift__(self, other, **kwargs):
        """
        Returns whether or not given interval is strictly right of another
        interval.

        [a, b] >> [c, d]        True, if a > d
        """
        other = self.coerce_arg(other)
        return self.op('>>')(other)

    def __lshift__(self, other, **kwargs):
        """
        Returns whether or not given interval is strictly left of another
        interval.

        [a, b] << [c, d]        True, if b < c
        """
        other = self.coerce_arg(other)
        return self.op('<<')(other)

    def contains(self, other, **kwargs):
        # PostgreSQL range containment operator: range @> element/range.
        other = self.coerce_arg(other)
        return self.op('@>')(other)

    def contained_by(self, other, **kwargs):
        # PostgreSQL "is contained by" operator: range <@ range.
        other = self.coerce_arg(other)
        return self.op('<@')(other)
|
||||
|
||||
|
||||
class DiscreteRangeComparator(RangeComparator):
    # Comparator for discrete ranges (int, date). Subclasses define ``step``;
    # it is subtracted to compensate for the exclusive upper bound of
    # canonical discrete ranges — TODO confirm against intervals semantics.
    @property
    def length(self):
        return sa.func.upper(self.expr) - self.step - sa.func.lower(self.expr)
|
||||
|
||||
|
||||
class IntRangeComparator(DiscreteRangeComparator):
    # Integer ranges advance in steps of 1.
    step = 1
|
||||
|
||||
|
||||
class DateRangeComparator(DiscreteRangeComparator):
    # Date ranges advance in whole-day steps.
    step = timedelta(days=1)
|
||||
|
||||
|
||||
class ContinuousRangeComparator(RangeComparator):
    # Comparator for continuous ranges (numeric, timestamp): length is
    # simply upper - lower.
    @property
    def length(self):
        return sa.func.upper(self.expr) - sa.func.lower(self.expr)
|
||||
|
||||
|
||||
# Rich-comparison methods to install on RangeComparator, each wrapped so
# the right-hand operand is coerced into an interval object first.
funcs = [
    '__eq__',
    '__ne__',
    '__lt__',
    '__le__',
    '__gt__',
    '__ge__',
]


for func in funcs:
    setattr(
        RangeComparator,
        func,
        RangeComparator.coerced_func(func)
    )
|
||||
|
||||
|
||||
class RangeType(ScalarCoercible, types.TypeDecorator):
    """Base class for all range column types.

    Subclasses set ``impl`` (the native PostgreSQL range type) and assign
    ``interval_class`` (the intervals-package class used on the Python
    side) in their ``__init__``.
    """
    comparator_factory = RangeComparator

    def __init__(self, *args, **kwargs):
        if intervals is None:
            raise ImproperlyConfigured(
                'RangeType needs intervals package installed.'
            )
        # Optional granularity: values get rounded to multiples of ``step``.
        self.step = kwargs.pop('step', None)
        super().__init__(*args, **kwargs)

    def load_dialect_impl(self, dialect):
        if dialect.name == 'postgresql':
            # Use the native range type for postgres.
            return dialect.type_descriptor(self.impl)
        else:
            # Other drivers don't have native types; store as a string.
            return dialect.type_descriptor(sa.String(255))

    def process_bind_param(self, value, dialect):
        # Intervals serialize via their string representation.
        if value is not None:
            return str(value)
        return value

    def process_result_value(self, value, dialect):
        # Native PG ranges come back as driver objects; the string fallback
        # needs explicit parsing via from_string.
        if isinstance(value, str):
            factory_func = self.interval_class.from_string
        else:
            factory_func = self.interval_class
        if value is not None:
            if self.interval_class.step is not None:
                # Stepped interval classes need canonicalization so bounds
                # land on step multiples.
                return self.canonicalize_result_value(
                    factory_func(value, step=self.step)
                )
            else:
                return factory_func(value, step=self.step)
        return value

    def canonicalize_result_value(self, value):
        # Normalize to a closed-closed canonical interval form.
        return intervals.canonicalize(value, True, True)

    def _coerce(self, value):
        # Coerce assigned lists/tuples/strings/scalars into interval objects.
        if value is None:
            return None
        return self.interval_class(value, step=self.step)
|
||||
|
||||
|
||||
class IntRangeType(RangeType):
    """
    IntRangeType provides way for saving ranges of integers into database. On
    PostgreSQL this type maps to native INT4RANGE type while on other drivers
    this maps to simple string column.

    Example::


        from sqlalchemy_utils import IntRangeType


        class Event(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, autoincrement=True)
            name = sa.Column(sa.Unicode(255))
            estimated_number_of_persons = sa.Column(IntRangeType)


        party = Event(name='party')

        # we estimate the party to contain minimum of 10 persons and at max
        # 100 persons
        party.estimated_number_of_persons = [10, 100]

        print(party.estimated_number_of_persons)
        # '10-100'


    IntRangeType returns the values as IntInterval objects. These objects
    support many arithmetic operators::


        meeting = Event(name='meeting')

        meeting.estimated_number_of_persons = [20, 40]

        total = (
            meeting.estimated_number_of_persons +
            party.estimated_number_of_persons
        )
        print(total)
        # '30-140'
    """
    impl = INT4RANGE
    comparator_factory = IntRangeComparator
    # Safe to use as part of SQLAlchemy's statement cache key.
    cache_ok = True

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # intervals.IntInterval handles value coercion/arithmetic.
        self.interval_class = intervals.IntInterval
|
||||
|
||||
|
||||
class Int8RangeType(RangeType):
    """
    Int8RangeType provides way for saving ranges of 8-byte integers into
    database. On PostgreSQL this type maps to native INT8RANGE type while on
    other drivers this maps to simple string column.

    Example::


        from sqlalchemy_utils import Int8RangeType


        class Event(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, autoincrement=True)
            name = sa.Column(sa.Unicode(255))
            estimated_number_of_persons = sa.Column(Int8RangeType)


        party = Event(name='party')

        # we estimate the party to contain minimum of 10 persons and at max
        # 100 persons
        party.estimated_number_of_persons = [10, 100]

        print(party.estimated_number_of_persons)
        # '10-100'


    Int8RangeType returns the values as IntInterval objects. These objects
    support many arithmetic operators::


        meeting = Event(name='meeting')

        meeting.estimated_number_of_persons = [20, 40]

        total = (
            meeting.estimated_number_of_persons +
            party.estimated_number_of_persons
        )
        print(total)
        # '30-140'
    """
    impl = INT8RANGE
    comparator_factory = IntRangeComparator
    # Safe to use as part of SQLAlchemy's statement cache key.
    cache_ok = True

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Same Python-side class as IntRangeType; only the SQL type differs.
        self.interval_class = intervals.IntInterval
|
||||
|
||||
|
||||
class DateRangeType(RangeType):
    """
    DateRangeType provides way for saving ranges of dates into database. On
    PostgreSQL this type maps to native DATERANGE type while on other drivers
    this maps to simple string column.

    Example::


        from sqlalchemy_utils import DateRangeType


        class Reservation(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, autoincrement=True)
            room_id = sa.Column(sa.Integer)
            during = sa.Column(DateRangeType)
    """
    impl = DATERANGE
    comparator_factory = DateRangeComparator
    # Safe to use as part of SQLAlchemy's statement cache key.
    cache_ok = True

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # intervals.DateInterval handles value coercion/arithmetic.
        self.interval_class = intervals.DateInterval
|
||||
|
||||
|
||||
class NumericRangeType(RangeType):
    """
    NumericRangeType provides way for saving ranges of decimals into database.
    On PostgreSQL this type maps to native NUMRANGE type while on other drivers
    this maps to simple string column.

    Example::


        from sqlalchemy_utils import NumericRangeType


        class Car(Base):
            __tablename__ = 'car'
            id = sa.Column(sa.Integer, autoincrement=True)
            name = sa.Column(sa.Unicode(255))
            price_range = sa.Column(NumericRangeType)
    """

    impl = NUMRANGE
    comparator_factory = ContinuousRangeComparator
    # Safe to use as part of SQLAlchemy's statement cache key.
    cache_ok = True

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # intervals.DecimalInterval handles value coercion/arithmetic.
        self.interval_class = intervals.DecimalInterval
|
||||
|
||||
|
||||
class DateTimeRangeType(RangeType):
    """
    DateTimeRangeType provides a way for saving ranges of datetimes into the
    database. On PostgreSQL this type maps to the native TSRANGE type while on
    other drivers it maps to a simple string column.
    """
    impl = TSRANGE
    comparator_factory = ContinuousRangeComparator
    cache_ok = True

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Values are represented as intervals.DateTimeInterval objects on the
        # Python side.
        self.interval_class = intervals.DateTimeInterval
|
|
@ -0,0 +1,6 @@
|
|||
class ScalarCoercible:
    """Mixin for column types that coerce raw attribute values.

    Subclasses implement :meth:`_coerce`; :meth:`coercion_listener` can then
    be attached as an SQLAlchemy attribute ``set`` event handler so assigned
    values are converted to the type's Python representation.
    """

    def _coerce(self, value):
        """Convert ``value`` to this type's Python form; must be overridden."""
        raise NotImplementedError

    def coercion_listener(self, target, value, oldvalue, initiator):
        """Attribute ``set`` event hook: delegate coercion to :meth:`_coerce`."""
        coerced = self._coerce(value)
        return coerced
|
|
@ -0,0 +1,97 @@
|
|||
import sqlalchemy as sa
|
||||
from sqlalchemy import types
|
||||
|
||||
|
||||
class ScalarListException(Exception):
    """Raised when a list value cannot be safely serialized by
    :class:`ScalarListType` (for example, when an item contains the
    configured separator string)."""
|
||||
|
||||
|
||||
class ScalarListType(types.TypeDecorator):
    """
    ScalarListType stores multiple scalar values in a single text column as a
    separator-delimited string (comma by default). On the Python side the
    value behaves like a plain list.

    Example::


        from sqlalchemy_utils import ScalarListType


        class User(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, autoincrement=True)
            hobbies = sa.Column(ScalarListType())


        user = User()
        user.hobbies = ['football', 'ice_hockey']
        session.commit()


    Pass a coercion callable to read typed lists back out, e.g. integer
    lists::

        class Player(Base):
            __tablename__ = 'player'
            id = sa.Column(sa.Integer, autoincrement=True)
            points = sa.Column(ScalarListType(int))


    The value is always stored as text. To use a real array column on
    PostgreSQL, combine with a variant::

        points = sa.Column(
            ARRAY(Integer).with_variant(ScalarListType(int), 'sqlite')
        )
    """
    impl = sa.UnicodeText()

    cache_ok = True

    def __init__(self, coerce_func=str, separator=','):
        """
        :param coerce_func: callable applied to each stored item when loading
            (defaults to ``str``)
        :param separator: delimiter placed between items in the stored string
        """
        self.coerce_func = coerce_func
        self.separator = str(separator)

    def process_bind_param(self, value, dialect):
        """Serialize a list into one separator-delimited unicode string.

        Example: ``[1, 2, 3]`` -> ``'1,2,3'``. Raises
        :class:`ScalarListException` if any item's string form contains the
        separator, since that would make the stored value ambiguous.
        """
        if value is None:
            return None
        for item in value:
            if self.separator in str(item):
                raise ScalarListException(
                    "List values can't contain string '%s' (its being used as "
                    "separator. If you wish for scalar list values to contain "
                    "these strings, use a different separator string.)"
                    % self.separator
                )
        return self.separator.join(str(item) for item in value)

    def process_result_value(self, value, dialect):
        """Split the stored string and coerce each item back to Python."""
        if value is None:
            return None
        if value == '':
            # An empty string round-trips as an empty list, not [''].
            return []
        return [
            self.coerce_func(item)
            for item in value.split(self.separator)
        ]
|
|
@ -0,0 +1,102 @@
|
|||
from sqlalchemy import types
|
||||
|
||||
from ..exceptions import ImproperlyConfigured
|
||||
from .scalar_coercible import ScalarCoercible
|
||||
|
||||
|
||||
class TimezoneType(ScalarCoercible, types.TypeDecorator):
    """
    TimezoneType provides a way for saving timezones objects into database.
    TimezoneType saves timezone objects as strings on the way in and converts
    them back to objects when querying the database.

    ::

        from sqlalchemy_utils import TimezoneType

        class User(Base):
            __tablename__ = 'user'

            # Pass backend='pytz' to change it to use pytz. Other values:
            # 'dateutil' (default), and 'zoneinfo'.
            timezone = sa.Column(TimezoneType(backend='pytz'))

    :param backend: Whether to use 'dateutil', 'pytz' or 'zoneinfo' for
        timezones. 'zoneinfo' uses the standard library module in Python 3.9+,
        but requires the external 'backports.zoneinfo' package for older
        Python versions.

    """

    # Stored as a short unicode string (e.g. a zone name like
    # 'Europe/Helsinki'); 50 chars is the column width.
    impl = types.Unicode(50)
    # Replaced per-backend in __init__ with the backend's timezone class.
    python_type = None
    cache_ok = True

    def __init__(self, backend='dateutil'):
        self.backend = backend
        if backend == 'dateutil':
            try:
                from dateutil.tz import tzfile
                from dateutil.zoneinfo import get_zonefile_instance

                self.python_type = tzfile
                # Look up tz objects by name from dateutil's bundled zone db.
                self._to = get_zonefile_instance().zones.get
                # Serialize by reading the tzfile's private _filename
                # attribute (the zone name it was loaded from).
                self._from = lambda x: str(x._filename)

            except ImportError:
                raise ImproperlyConfigured(
                    "'python-dateutil' is required to use the "
                    "'dateutil' backend for 'TimezoneType'"
                )

        elif backend == 'pytz':
            try:
                from pytz import timezone
                from pytz.tzinfo import BaseTzInfo

                self.python_type = BaseTzInfo
                self._to = timezone
                self._from = str

            except ImportError:
                raise ImproperlyConfigured(
                    "'pytz' is required to use the 'pytz' backend "
                    "for 'TimezoneType'"
                )

        elif backend == "zoneinfo":
            # Prefer the stdlib module (3.9+); fall back to the
            # 'backports.zoneinfo' package on older interpreters.
            try:
                import zoneinfo
            except ImportError:
                try:
                    from backports import zoneinfo
                except ImportError:
                    raise ImproperlyConfigured(
                        "'backports.zoneinfo' is required to use "
                        "the 'zoneinfo' backend for 'TimezoneType'"
                        "on Python version < 3.9"
                    )

            self.python_type = zoneinfo.ZoneInfo
            self._to = zoneinfo.ZoneInfo
            self._from = str

        else:
            raise ImproperlyConfigured(
                "'pytz', 'dateutil' or 'zoneinfo' are the backends "
                "supported for 'TimezoneType'"
            )

    def _coerce(self, value):
        """Coerce a zone name (or other raw value) into a timezone object.

        :raises ValueError: if the backend cannot resolve ``value``.
        """
        if value is not None and not isinstance(value, self.python_type):
            obj = self._to(value)
            if obj is None:
                # dateutil's .get returns None for unknown names instead of
                # raising; normalize that to a ValueError.
                raise ValueError("unknown time zone '%s'" % value)
            return obj
        return value

    def process_bind_param(self, value, dialect):
        """Serialize a timezone object to its string name for storage."""
        return self._from(self._coerce(value)) if value else None

    def process_result_value(self, value, dialect):
        """Convert a stored zone name back to a timezone object."""
        return self._to(value) if value else None
|
|
@ -0,0 +1,108 @@
|
|||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import TSVECTOR
|
||||
|
||||
|
||||
class TSVectorType(sa.types.TypeDecorator):
    """
    .. note::

        This type is PostgreSQL specific and is not supported by other
        dialects.

    Provides additional functionality for SQLAlchemy PostgreSQL dialect's
    TSVECTOR_ type. This additional functionality includes:

    * Vector concatenation
    * regconfig constructor parameter which is applied to match function if no
      postgresql_regconfig parameter is given
    * Provides extensible base for extensions such as SQLAlchemy-Searchable_

    .. _TSVECTOR:
        https://docs.sqlalchemy.org/en/latest/dialects/postgresql.html#full-text-search

    .. _SQLAlchemy-Searchable:
        https://www.github.com/kvesteri/sqlalchemy-searchable

    ::

        from sqlalchemy_utils import TSVectorType


        class Article(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.String(100))
            search_vector = sa.Column(TSVectorType)


        # Find all articles whose name matches 'finland'
        session.query(Article).filter(Article.search_vector.match('finland'))


    TSVectorType also supports vector concatenation.

    ::


        class Article(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.String(100))
            name_vector = sa.Column(TSVectorType)
            content = sa.Column(sa.String)
            content_vector = sa.Column(TSVectorType)

        # Find all articles whose name or content matches 'finland'
        session.query(Article).filter(
            (Article.name_vector | Article.content_vector).match('finland')
        )

    You can configure TSVectorType to use a specific regconfig.
    ::

        class Article(Base):
            __tablename__ = 'user'
            id = sa.Column(sa.Integer, primary_key=True)
            name = sa.Column(sa.String(100))
            search_vector = sa.Column(
                TSVectorType(regconfig='pg_catalog.simple')
            )


    Now expression such as::


        Article.search_vector.match('finland')


    Would be equivalent to SQL::


        search_vector @@ to_tsquery('pg_catalog.simple', 'finland')

    """
    impl = TSVECTOR
    cache_ok = True

    class comparator_factory(TSVECTOR.Comparator):
        # Extends the dialect's comparator so column expressions gain a
        # regconfig-aware match() and '||' concatenation via '|'.
        def match(self, other, **kwargs):
            # If the caller did not specify a regconfig, fall back to the
            # one given to the type's constructor (if any).
            if 'postgresql_regconfig' not in kwargs:
                if 'regconfig' in self.type.options:
                    kwargs['postgresql_regconfig'] = (
                        self.type.options['regconfig']
                    )
            return TSVECTOR.Comparator.match(self, other, **kwargs)

        def __or__(self, other):
            # Map Python '|' to PostgreSQL's tsvector concatenation
            # operator '||'.
            return self.op('||')(other)

    def __init__(self, *args, **kwargs):
        """
        Initializes new TSVectorType

        :param *args: list of column names
        :param **kwargs: various other options for this TSVectorType
        """
        self.columns = args
        self.options = kwargs
        super().__init__()
|
|
@ -0,0 +1,68 @@
|
|||
from sqlalchemy import types
|
||||
|
||||
from .scalar_coercible import ScalarCoercible
|
||||
|
||||
furl = None
|
||||
try:
|
||||
from furl import furl
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
class URLType(ScalarCoercible, types.TypeDecorator):
    """
    URLType stores furl_ objects into database.

    .. _furl: https://github.com/gruns/furl

    ::

        from sqlalchemy_utils import URLType
        from furl import furl


        class User(Base):
            __tablename__ = 'user'

            id = sa.Column(sa.Integer, primary_key=True)
            website = sa.Column(URLType)


        user = User(website='www.example.com')

        # website is coerced to a furl object, so all furl operations
        # become available
        user.website.args['some_argument'] = '12'

        print(user.website)
        # www.example.com?some_argument=12
    """

    impl = types.UnicodeText
    cache_ok = True

    def process_bind_param(self, value, dialect):
        """Serialize a furl object (or pass a plain string through)."""
        if furl is not None and isinstance(value, furl):
            return str(value)
        if isinstance(value, str):
            return value
        return None

    def process_result_value(self, value, dialect):
        """Wrap the stored string in a furl object when furl is installed."""
        if furl is None or value is None:
            return value
        return furl(value)

    def _coerce(self, value):
        """Coerce assigned strings to furl objects (no-op without furl)."""
        if furl is None or value is None or isinstance(value, furl):
            return value
        return furl(value)

    @property
    def python_type(self):
        # Delegate to the underlying UnicodeText implementation.
        return self.impl.type.python_type
|
|
@ -0,0 +1,113 @@
|
|||
import uuid
|
||||
|
||||
from sqlalchemy import types, util
|
||||
from sqlalchemy.dialects import mssql, postgresql
|
||||
|
||||
from ..compat import get_sqlalchemy_version
|
||||
from .scalar_coercible import ScalarCoercible
|
||||
|
||||
sqlalchemy_version = get_sqlalchemy_version()
|
||||
|
||||
|
||||
class UUIDType(ScalarCoercible, types.TypeDecorator):
    """
    Stores a UUID in the database natively when it can and falls back to
    a BINARY(16) or a CHAR(32) when it can't.

    ::

        from sqlalchemy_utils import UUIDType
        import uuid

        class User(Base):
            __tablename__ = 'user'

            # Pass `binary=False` to fallback to CHAR instead of BINARY
            id = sa.Column(
                UUIDType(binary=False),
                primary_key=True,
                default=uuid.uuid4
            )
    """
    impl = types.BINARY(16)

    python_type = uuid.UUID

    cache_ok = True

    def __init__(self, binary=True, native=True):
        """
        :param binary: Whether to use a BINARY(16) or CHAR(32) fallback.
        :param native: Whether to use the dialect's native UUID type when
            available (PostgreSQL, CockroachDB, MSSQL).
        """
        self.binary = binary
        self.native = native

    def __repr__(self):
        return util.generic_repr(self)

    def load_dialect_impl(self, dialect):
        """Pick the concrete column type for the active dialect."""
        if self.native and dialect.name in ('postgresql', 'cockroachdb'):
            # Use the native UUID type.
            return dialect.type_descriptor(postgresql.UUID())

        if dialect.name == 'mssql' and self.native:
            # Use the native UNIQUEIDENTIFIER type.
            return dialect.type_descriptor(mssql.UNIQUEIDENTIFIER())

        else:
            # Fallback to either a BINARY or a CHAR.
            kind = self.impl if self.binary else types.CHAR(32)
            return dialect.type_descriptor(kind)

    @staticmethod
    def _coerce(value):
        """Coerce strings (hex form) or raw bytes into uuid.UUID objects."""
        if value and not isinstance(value, uuid.UUID):
            try:
                value = uuid.UUID(value)

            except (TypeError, ValueError):
                # Not a valid hex string; try interpreting as 16 raw bytes.
                value = uuid.UUID(bytes=value)

        return value

    # sqlalchemy >= 1.4.30 quotes UUID's automatically.
    # It is only necessary to quote UUID's in sqlalchemy < 1.4.30.
    if sqlalchemy_version < (1, 4, 30):
        def process_literal_param(self, value, dialect):
            return f"'{value}'" if value else value
    else:
        def process_literal_param(self, value, dialect):
            return value

    def process_bind_param(self, value, dialect):
        """Serialize a UUID for storage: native str, bytes, or hex."""
        if value is None:
            return value

        if not isinstance(value, uuid.UUID):
            value = self._coerce(value)

        if self.native and dialect.name in (
            'postgresql',
            'mssql',
            'cockroachdb'
        ):
            return str(value)

        return value.bytes if self.binary else value.hex

    def process_result_value(self, value, dialect):
        """Convert a stored value back into a uuid.UUID."""
        if value is None:
            return value

        if self.native and dialect.name in (
            'postgresql',
            'mssql',
            'cockroachdb'
        ):
            if isinstance(value, uuid.UUID):
                # Some drivers convert PostgreSQL's uuid values to
                # Python's uuid.UUID objects by themselves
                return value
            return uuid.UUID(value)

        return uuid.UUID(bytes=value) if self.binary else uuid.UUID(value)
|
|
@ -0,0 +1,82 @@
|
|||
from sqlalchemy import types
|
||||
|
||||
from .. import i18n
|
||||
from ..exceptions import ImproperlyConfigured
|
||||
from ..primitives import WeekDay, WeekDays
|
||||
from .bit import BitType
|
||||
from .scalar_coercible import ScalarCoercible
|
||||
|
||||
|
||||
class WeekDaysType(types.TypeDecorator, ScalarCoercible):
    """
    WeekDaysType offers way of saving WeekDays objects into database. The
    WeekDays objects are converted to bit strings on the way in and back to
    WeekDays objects on the way out.

    In order to use WeekDaysType you need to install Babel_ first.

    .. _Babel: https://babel.pocoo.org/

    ::


        from sqlalchemy_utils import WeekDaysType, WeekDays
        from babel import Locale


        class Schedule(Base):
            __tablename__ = 'schedule'
            id = sa.Column(sa.Integer, autoincrement=True)
            working_days = sa.Column(WeekDaysType)


        schedule = Schedule()
        schedule.working_days = WeekDays('0001111')
        session.add(schedule)
        session.commit()

        print(schedule.working_days)  # Thursday, Friday, Saturday, Sunday


    WeekDaysType also supports scalar coercion:

    ::


        schedule.working_days = '1110000'
        schedule.working_days  # WeekDays object

    """

    # One bit per week day.
    impl = BitType(WeekDay.NUM_WEEK_DAYS)

    cache_ok = True

    def __init__(self, *args, **kwargs):
        # WeekDays rendering relies on Babel locale data; fail fast if it
        # is not installed.
        if i18n.babel is None:
            raise ImproperlyConfigured(
                "'babel' package is required to use 'WeekDaysType'"
            )

        super().__init__(*args, **kwargs)

    @property
    def comparator_factory(self):
        # Delegate comparisons to the underlying BitType.
        return self.impl.comparator_factory

    def process_bind_param(self, value, dialect):
        """Serialize a WeekDays object to its bit-string form."""
        if isinstance(value, WeekDays):
            value = value.as_bit_string()

        # MySQL expects the bit value as bytes rather than str.
        if dialect.name == 'mysql':
            return bytes(value, 'utf8')
        return value

    def process_result_value(self, value, dialect):
        """Convert a stored bit string back to a WeekDays object."""
        if value is not None:
            return WeekDays(value)

    def _coerce(self, value):
        """Coerce assigned bit strings to WeekDays objects."""
        if value is not None and not isinstance(value, WeekDays):
            return WeekDays(value)
        return value
|
|
@ -0,0 +1,22 @@
|
|||
from collections.abc import Iterable
|
||||
|
||||
|
||||
def str_coercible(cls):
    """Class decorator deriving ``__str__`` from the class's ``__unicode__``.

    Allows classes that define only ``__unicode__`` to be used anywhere a
    plain string conversion is expected.
    """
    def _delegating_str(self):
        return self.__unicode__()

    cls.__str__ = _delegating_str
    return cls
|
||||
|
||||
|
||||
def is_sequence(value):
    """Return True if ``value`` is iterable but not a plain string."""
    if isinstance(value, str):
        return False
    return isinstance(value, Iterable)
|
||||
|
||||
|
||||
def starts_with(iterable, prefix):
    """
    Returns whether or not given iterable starts with given prefix.
    """
    expected = list(prefix)
    return list(iterable)[:len(expected)] == expected
|
210
elitebot/lib/python3.11/site-packages/sqlalchemy_utils/view.py
Normal file
210
elitebot/lib/python3.11/site-packages/sqlalchemy_utils/view.py
Normal file
|
@ -0,0 +1,210 @@
|
|||
import sqlalchemy as sa
|
||||
from sqlalchemy.ext import compiler
|
||||
from sqlalchemy.schema import DDLElement, PrimaryKeyConstraint
|
||||
from sqlalchemy.sql.expression import ClauseElement, Executable
|
||||
|
||||
from sqlalchemy_utils.functions import get_columns
|
||||
|
||||
|
||||
class CreateView(DDLElement):
    """DDL element that renders ``CREATE [MATERIALIZED] VIEW name AS ...``."""

    def __init__(self, name, selectable, materialized=False):
        self.materialized = materialized
        self.name = name
        self.selectable = selectable
|
||||
|
||||
|
||||
@compiler.compiles(CreateView)
def compile_create_materialized_view(element, compiler, **kw):
    """Render the CREATE VIEW statement for a :class:`CreateView` element."""
    prefix = 'MATERIALIZED ' if element.materialized else ''
    quoted_name = compiler.dialect.identifier_preparer.quote(element.name)
    # literal_binds inlines parameters, since DDL cannot carry bind params.
    body = compiler.sql_compiler.process(element.selectable, literal_binds=True)
    return f'CREATE {prefix}VIEW {quoted_name} AS {body}'
|
||||
|
||||
|
||||
class DropView(DDLElement):
    """DDL element that renders ``DROP [MATERIALIZED] VIEW IF EXISTS name``."""

    def __init__(self, name, materialized=False, cascade=True):
        self.cascade = cascade
        self.materialized = materialized
        self.name = name
|
||||
|
||||
|
||||
@compiler.compiles(DropView)
def compile_drop_materialized_view(element, compiler, **kw):
    """Render the DROP VIEW statement for a :class:`DropView` element."""
    prefix = 'MATERIALIZED ' if element.materialized else ''
    quoted_name = compiler.dialect.identifier_preparer.quote(element.name)
    suffix = 'CASCADE' if element.cascade else ''
    return f'DROP {prefix}VIEW IF EXISTS {quoted_name} {suffix}'
|
||||
|
||||
|
||||
def create_table_from_selectable(
    name,
    selectable,
    indexes=None,
    metadata=None,
    aliases=None,
    **kwargs
):
    """Build an ``sa.Table`` mirroring the columns of ``selectable``.

    :param name: name for the new table
    :param selectable: SQLAlchemy selectable whose columns are mirrored
    :param indexes: optional list of Index instances appended to the table
    :param metadata: MetaData to attach the table to (fresh one if None)
    :param aliases: optional mapping of column name -> attribute key
    :param kwargs: forwarded to the ``sa.Table`` constructor
    """
    if indexes is None:
        indexes = []
    if metadata is None:
        metadata = sa.MetaData()
    if aliases is None:
        aliases = {}

    columns = [
        sa.Column(
            col.name,
            col.type,
            key=aliases.get(col.name, col.name),
            primary_key=col.primary_key
        )
        for col in get_columns(selectable)
    ]
    table = sa.Table(name, metadata, *(columns + indexes), **kwargs)

    # SQLAlchemy requires a primary key for ORM mapping; if the selectable
    # exposes none, declare all columns as a composite primary key.
    if not any(col.primary_key for col in get_columns(selectable)):
        table.append_constraint(
            PrimaryKeyConstraint(*(col.name for col in get_columns(selectable)))
        )
    return table
|
||||
|
||||
|
||||
def create_materialized_view(
    name,
    selectable,
    metadata,
    indexes=None,
    aliases=None
):
    """ Create a view on a given metadata

    :param name: The name of the view to create.
    :param selectable: An SQLAlchemy selectable e.g. a select() statement.
    :param metadata:
        An SQLAlchemy Metadata instance that stores the features of the
        database being described.
    :param indexes: An optional list of SQLAlchemy Index instances.
    :param aliases:
        An optional dictionary containing with keys as column names and values
        as column aliases.

    Same as for ``create_view`` except that a ``CREATE MATERIALIZED VIEW``
    statement is emitted instead of a ``CREATE VIEW``.

    """
    # NOTE(review): metadata=None here — presumably deliberate so the Table
    # is not registered on the caller's metadata (the view is created via the
    # DDL events below, not by metadata.create_all) — confirm.
    table = create_table_from_selectable(
        name=name,
        selectable=selectable,
        indexes=indexes,
        metadata=None,
        aliases=aliases
    )

    # Emit CREATE MATERIALIZED VIEW when the metadata is created.
    sa.event.listen(
        metadata,
        'after_create',
        CreateView(name, selectable, materialized=True)
    )

    # Create any requested indexes once the view itself exists.
    @sa.event.listens_for(metadata, 'after_create')
    def create_indexes(target, connection, **kw):
        for idx in table.indexes:
            idx.create(connection)

    # Emit DROP MATERIALIZED VIEW before the metadata is dropped.
    sa.event.listen(
        metadata,
        'before_drop',
        DropView(name, materialized=True)
    )
    return table
|
||||
|
||||
|
||||
def create_view(
    name,
    selectable,
    metadata,
    cascade_on_drop=True
):
    """ Create a view on a given metadata

    :param name: The name of the view to create.
    :param selectable: An SQLAlchemy selectable e.g. a select() statement.
    :param metadata:
        An SQLAlchemy Metadata instance that stores the features of the
        database being described.
    :param cascade_on_drop:
        Whether to append ``CASCADE`` to the emitted ``DROP VIEW`` statement.

    The process for creating a view is similar to the standard way that a
    table is constructed, except that a selectable is provided instead of
    a set of columns. The view is created once a ``CREATE`` statement is
    executed against the supplied metadata (e.g. ``metadata.create_all(..)``),
    and dropped when a ``DROP`` is executed against the metadata.

    To create a view that performs basic filtering on a table. ::

        metadata = MetaData()
        users = Table('users', metadata,
                Column('id', Integer, primary_key=True),
                Column('name', String),
                Column('fullname', String),
                Column('premium_user', Boolean, default=False),
            )

        premium_members = select(users).where(users.c.premium_user == True)
        # sqlalchemy 1.3:
        # premium_members = select([users]).where(users.c.premium_user == True)
        create_view('premium_users', premium_members, metadata)

        metadata.create_all(engine) # View is created at this point

    """
    # NOTE(review): metadata=None here — presumably deliberate so the Table
    # is not registered on the caller's metadata (the view is created via the
    # DDL events below, not as a real table by create_all) — confirm.
    table = create_table_from_selectable(
        name=name,
        selectable=selectable,
        metadata=None
    )

    # Emit CREATE VIEW when the metadata is created.
    sa.event.listen(metadata, 'after_create', CreateView(name, selectable))

    # Create any indexes attached to the table once the view exists.
    @sa.event.listens_for(metadata, 'after_create')
    def create_indexes(target, connection, **kw):
        for idx in table.indexes:
            idx.create(connection)

    # Emit DROP VIEW before the metadata is dropped.
    sa.event.listen(
        metadata,
        'before_drop',
        DropView(name, cascade=cascade_on_drop)
    )
    return table
|
||||
|
||||
|
||||
class RefreshMaterializedView(Executable, ClauseElement):
    """Executable clause rendering ``REFRESH MATERIALIZED VIEW``."""

    inherit_cache = True

    def __init__(self, name, concurrently):
        self.concurrently = concurrently
        self.name = name
|
||||
|
||||
|
||||
@compiler.compiles(RefreshMaterializedView)
def compile_refresh_materialized_view(element, compiler, **kw):
    """Render ``REFRESH MATERIALIZED VIEW [CONCURRENTLY] <name>``.

    Accepts ``**kw`` like the other ``@compiler.compiles`` handlers in this
    module (``compile_create_materialized_view`` and
    ``compile_drop_materialized_view``); SQLAlchemy's compiler dispatch may
    pass keyword arguments to visitors, and omitting ``**kw`` would raise a
    TypeError in that case.
    """
    return 'REFRESH MATERIALIZED VIEW {concurrently}{name}'.format(
        concurrently='CONCURRENTLY ' if element.concurrently else '',
        name=compiler.dialect.identifier_preparer.quote(element.name),
    )
|
||||
|
||||
|
||||
def refresh_materialized_view(session, name, concurrently=False):
    """ Refreshes an already existing materialized view

    :param session: An SQLAlchemy Session instance.
    :param name: The name of the materialized view to refresh.
    :param concurrently:
        Optional flag that causes the ``CONCURRENTLY`` parameter
        to be specified when the materialized view is refreshed.
    """
    # session.execute() bypasses autoflush, so flush manually to make sure
    # newly-created/modified objects are included in the refresh.
    session.flush()
    statement = RefreshMaterializedView(name, concurrently)
    session.execute(statement)
|
Loading…
Add table
Add a link
Reference in a new issue