forked from Raiza.dev/EliteBot
185 lines
6.2 KiB
Python
185 lines
6.2 KiB
Python
from collections.abc import Iterable
|
|
|
|
import sqlalchemy as sa
|
|
from sqlalchemy.ext.hybrid import hybrid_property
|
|
from sqlalchemy.orm import attributes, class_mapper, ColumnProperty
|
|
from sqlalchemy.orm.interfaces import MapperProperty, PropComparator
|
|
from sqlalchemy.orm.session import _state_session
|
|
from sqlalchemy.util import set_creation_order
|
|
|
|
from .exceptions import ImproperlyConfigured
|
|
from .functions import identity
|
|
from .functions.orm import _get_class_registry
|
|
|
|
|
|
class GenericAttributeImpl(attributes.ScalarAttributeImpl):
    """Attribute implementation backing a generic relationship.

    Reading resolves the related object from the discriminator and id
    values held on the owning state; writing stores the related object and
    copies its class name and primary key into the owner's discriminator
    and id attributes.
    """

    def get(self, state, dict_, passive=attributes.PASSIVE_OFF):
        """Return the object this attribute refers to for ``state``.

        A value already present in ``dict_`` under this attribute's key is
        returned directly. Otherwise the target class is looked up from the
        discriminator value and the row is fetched through the session the
        state is attached to.

        :param state: instance state of the owning object.
        :param dict_: attribute dictionary of the owning object.
        :param passive: accepted for signature compatibility; not consulted
            by this implementation.
        :returns: the related object, or ``None`` when the state has no
            session, the discriminator maps to no class in the registry,
            or the lookup finds nothing.
        """
        if self.key in dict_:
            return dict_[self.key]

        # A lazy lookup is only possible through the session that owns
        # the state; a detached state yields nothing.
        session = _state_session(state)
        if session is None:
            return None

        # Translate the stored discriminator value into a class.
        # TODO: Perhaps optimize with some sort of lookup?
        discriminator = self.get_state_discriminator(state)
        registry = _get_class_registry(state.class_)
        target_class = registry.get(discriminator)
        if target_class is None:
            # Unknown discriminator; nothing to load.
            return None

        identity_values = self.get_state_id(state)

        try:
            return session.get(target_class, identity_values)
        except AttributeError:
            # sqlalchemy 1.3
            return session.query(target_class).get(identity_values)

    def get_state_discriminator(self, state):
        """Return the discriminator value currently held by ``state``.

        A hybrid-property discriminator is read from the live object by
        attribute name; any other discriminator is read from the state's
        attribute collection by key.
        """
        discriminator = self.parent_token.discriminator
        if not isinstance(discriminator, hybrid_property):
            return state.attrs[discriminator.key].value
        return getattr(state.obj(), discriminator.__name__)

    def get_state_id(self, state):
        """Return a tuple of the id attribute values held by ``state``.

        Values appear in the same order as the id properties configured on
        the parent relationship property.
        """
        id_props = self.parent_token.id
        return tuple(state.attrs[prop.key].value for prop in id_props)

    def set(self, state, dict_, initiator,
            passive=attributes.PASSIVE_OFF,
            check_old=None,
            pop=False):
        """Assign the related object and sync discriminator and id values.

        Despite its name, ``initiator`` is treated here as the value being
        assigned: it is stored under this attribute's key, and its class
        name and primary key are written to the discriminator and id keys
        of ``dict_``. Assigning ``None`` sets all of those keys to ``None``.

        :param state: instance state of the owning object (unused here).
        :param dict_: attribute dictionary of the owning object; mutated.
        :param initiator: the object being assigned, or ``None``.
        :param passive: accepted for signature compatibility; unused.
        :param check_old: accepted for signature compatibility; unused.
        :param pop: accepted for signature compatibility; unused.
        """
        # Record the related object itself first.
        dict_[self.key] = initiator
        token = self.parent_token

        if initiator is None:
            # Clearing the relationship nulls every id part and the
            # discriminator.
            for id_prop in token.id:
                dict_[id_prop.key] = None
            dict_[token.discriminator.key] = None
            return

        # Derive the primary key of the assigned object from its mapper;
        # element 1 of the identity key is used as the pk sequence.
        target_class = type(initiator)
        target_mapper = class_mapper(target_class)
        pk = target_mapper.identity_key_from_instance(initiator)[1]

        # The discriminator stored is the class name of the target.
        type_name = target_class.__name__

        for position, id_prop in enumerate(token.id):
            dict_[id_prop.key] = pk[position]
        dict_[token.discriminator.key] = type_name
|
|
|
|
|
|
class GenericRelationshipProperty(MapperProperty):
    """A generic form of the relationship property.

    Creates a 1 to many relationship between the parent model
    and any other models using a discriminator (the class name of the
    related model, as compared by :class:`Comparator`).

    :param discriminator:
        Field to discriminate which model we are referring to.
    :param id:
        Field to point to the model we are referring to.
    :param doc:
        Optional documentation string passed on when the attribute is
        registered on the mapped class.
    """

    def __init__(self, discriminator, id, doc=None):
        super().__init__()
        # Raw user-supplied references; resolved to columns in ``init``.
        self._discriminator_col = discriminator
        self._id_cols = id
        self._id = None
        self._discriminator = None
        self.doc = doc

        set_creation_order(self)

    def _column_to_property(self, column):
        """Return the parent mapper's attribute matching ``column``.

        For a hybrid property the descriptor registered under the same
        name is returned; otherwise the first ``ColumnProperty`` whose
        leading column has the same name. Returns ``None`` when nothing
        matches.
        """
        if isinstance(column, hybrid_property):
            attr_key = column.__name__
            for key, attr in self.parent.all_orm_descriptors.items():
                if key == attr_key:
                    return attr
        else:
            for attr in self.parent.attrs.values():
                if isinstance(attr, ColumnProperty):
                    if attr.columns[0].name == column.name:
                        return attr
        return None

    def init(self):
        """Resolve the configured discriminator and id references.

        String references are replaced by the parent mapper's columns, the
        id reference is normalised to a list, and both are mapped to the
        corresponding mapper attributes (``self.discriminator`` and
        ``self.id``).

        :raises ImproperlyConfigured: if no attribute can be found for the
            discriminator.
        """
        def convert_strings(column):
            # Column names given as strings are looked up on the parent.
            if isinstance(column, str):
                return self.parent.columns[column]
            return column

        self._discriminator_col = convert_strings(self._discriminator_col)
        # Convert a lone string first so that the iterable check below
        # does not see the string itself as a sequence of columns.
        self._id_cols = convert_strings(self._id_cols)

        if isinstance(self._id_cols, Iterable):
            self._id_cols = list(map(convert_strings, self._id_cols))
        else:
            self._id_cols = [self._id_cols]

        self.discriminator = self._column_to_property(self._discriminator_col)

        if self.discriminator is None:
            raise ImproperlyConfigured(
                'Could not find discriminator descriptor.'
            )

        self.id = list(map(self._column_to_property, self._id_cols))

    class Comparator(PropComparator):
        """Builds SQL expressions for comparing the generic relationship."""

        def __init__(self, prop, parentmapper):
            self.property = prop
            self._parententity = parentmapper

        def __eq__(self, other):
            """Return an expression matching rows that refer to ``other``.

            The discriminator column must equal the class name of
            ``other`` and each id column must equal the corresponding
            element of ``identity(other)``.
            """
            discriminator = type(other).__name__
            q = self.property._discriminator_col == discriminator
            other_id = identity(other)
            for index, id_col in enumerate(self.property._id_cols):
                q &= id_col == other_id[index]
            return q

        def __ne__(self, other):
            """Return the negation of the ``==`` expression for ``other``."""
            return ~(self == other)

        def is_type(self, other):
            """Return an expression matching rows that refer to ``other``.

            The discriminator column is tested against the class name of
            ``other`` and of every mapped subclass, at any depth of the
            inheritance hierarchy.
            """
            mapper = sa.inspect(other)
            class_names = [other.__name__]
            # ``self_and_descendants`` walks the whole inheritance tree;
            # the direct sub-mapper list alone would miss grandchildren.
            class_names.extend(
                submapper.class_.__name__
                for submapper in mapper.self_and_descendants
                if submapper is not mapper
            )

            return self.property._discriminator_col.in_(class_names)

    def instrument_class(self, mapper):
        """Register this property as an attribute on ``mapper.class_``.

        The attribute uses :class:`GenericAttributeImpl` for loading and
        assignment and :class:`Comparator` for query expressions; this
        property is passed along as the ``parent_token``.
        """
        attributes.register_attribute(
            mapper.class_,
            self.key,
            comparator=self.Comparator(self, mapper),
            parententity=mapper,
            doc=self.doc,
            impl_class=GenericAttributeImpl,
            parent_token=self
        )
|
|
|
|
|
|
def generic_relationship(*args, **kwargs):
    """Build a :class:`GenericRelationshipProperty`.

    All positional and keyword arguments are forwarded unchanged to the
    property's constructor.
    """
    relationship_property = GenericRelationshipProperty(*args, **kwargs)
    return relationship_property
|