# Source code for sqlalchemy_continuum.model_builder

from copy import copy
import six
import sqlalchemy as sa
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.orm import column_property
from sqlalchemy_utils.functions import get_declarative_base

from .utils import adapt_columns, option
from .version import VersionClassBase


def find_closest_versioned_parent(manager, model):
    """
    Return the version class of the nearest versioned direct base of ``model``.

    Only the direct bases (``model.__bases__``) are inspected, in declaration
    order. The first one registered in ``manager.version_class_map`` wins.
    Returns None when no direct base is versioned.
    """
    registry = manager.version_class_map
    return next(
        (registry[base] for base in model.__bases__ if base in registry),
        None
    )

def versioned_parents(manager, model):
    """
    Yield the version class of every versioned class in ``model``'s MRO.

    Classes are visited in method-resolution order, starting with ``model``
    itself. Those missing from ``manager.version_class_map`` are skipped.
    """
    version_map = manager.version_class_map
    for ancestor in model.__mro__:
        if ancestor not in version_map:
            continue
        yield version_map[ancestor]


def get_base_class(manager, model):
    """
    Return the tuple of base classes to use for a history (version) model.

    If ``model`` configures a truthy ``base_classes`` option, that value is
    returned as-is. Otherwise the result is a 1-tuple holding the model's
    declarative base. ``manager`` is unused here; it is accepted so this
    function fits the ``base_class_factory(manager, parent_cls)`` signature
    expected by ``version_base``.
    """
    explicit_bases = option(model, 'base_classes')
    if explicit_bases:
        return explicit_bases
    return (get_declarative_base(model), )


def version_base(manager, parent_cls, base_class_factory=None):
    """
    Return the base class for the version class of ``parent_cls``.

    If a direct base of ``parent_cls`` already has a version class, that
    version class is reused so the version models mirror the parent
    hierarchy. Otherwise a new abstract ``VersionBase`` type is created.
    Its bases are whatever ``base_class_factory(manager, parent_cls)``
    returns (defaulting to ``get_base_class``), followed by
    ``VersionClassBase``.
    """
    parent_version = find_closest_versioned_parent(manager, parent_cls)
    if parent_version:
        return parent_version

    factory = (
        get_base_class if base_class_factory is None else base_class_factory
    )
    bases = factory(manager, parent_cls) + (VersionClassBase, )
    return type('VersionBase', bases, {'__abstract__': True})


def copy_mapper_args(model):
    """
    Return the subset of ``model.__mapper_args__`` to reuse on its version class.

    - ``with_polymorphic``, ``polymorphic_identity`` and ``concrete`` are
      copied verbatim when present.
    - ``order_by`` is copied only when it is a string.
    - ``polymorphic_on`` is copied as a string. A column object is reduced
      to its ``.key``.

    A model without ``__mapper_args__`` yields an empty dict.
    """
    copied = {}
    if not hasattr(model, '__mapper_args__'):
        return copied
    source = model.__mapper_args__

    for name in ('with_polymorphic', 'polymorphic_identity', 'concrete'):
        if name in source:
            copied[name] = source[name]

    # Non-string order_by values (e.g. column expressions) refer to the
    # parent table, so only string-based ordering is carried over.
    if 'order_by' in source and isinstance(
        source['order_by'], six.string_types
    ):
        copied['order_by'] = source['order_by']

    if 'polymorphic_on' in source:
        discriminator = source['polymorphic_on']
        copied['polymorphic_on'] = (
            discriminator
            if isinstance(discriminator, six.string_types)
            else discriminator.key
        )
    return copied


class ModelBuilder(object):
    """
    ModelBuilder handles the building of Version models based on parent
    table attributes and versioning configuration.

    Calling an instance builds the version class for ``model`` and wires up
    its relationships to the parent model and to the transaction class.
    """

    def __init__(self, versioning_manager, model):
        """
        :param versioning_manager:
            VersioningManager object
        :param model:
            SQLAlchemy declarative model object that acts as a parent for the
            built version model
        """
        self.manager = versioning_manager
        self.model = model

    def build_parent_relationship(self):
        """
        Builds a relationship between the currently built version class and
        the parent class (the model whose history the version class
        represents).

        The join condition equates every primary key attribute of the parent
        model with the same-named attribute of the version class. The parent
        gets a view-only, dynamic ``versions`` relationship ordered by the
        transaction column. The version class gets a ``version_parent``
        backref.
        """
        conditions = []
        foreign_keys = []
        for key, column in sa.inspect(self.model).columns.items():
            if column.primary_key:
                conditions.append(
                    getattr(self.model, key)
                    ==
                    getattr(self.version_class, key)
                )
                foreign_keys.append(
                    getattr(self.version_class, key)
                )

        # The versions relation may already be set for the parent class
        # (hasattr also sees attributes inherited from a base class), in
        # which case it is left untouched.
        if not hasattr(self.model, 'versions'):
            self.model.versions = sa.orm.relationship(
                self.version_class,
                primaryjoin=sa.and_(*conditions),
                foreign_keys=foreign_keys,
                # Passed as a callable so SQLAlchemy evaluates it later
                # rather than at relationship definition time.
                order_by=lambda: getattr(
                    self.version_class,
                    option(self.model, 'transaction_column_name')
                ),
                lazy='dynamic',
                backref=sa.orm.backref(
                    'version_parent'
                ),
                viewonly=True
            )

    def build_transaction_relationship(self, tx_class):
        """
        Builds a relationship between the currently built version class and
        the Transaction class.

        The relationship is only defined if the version class does not
        already have a ``transaction`` attribute (e.g. inherited from a
        parent version class).

        :param tx_class: Transaction class; joined on its ``id`` attribute
        """
        transaction_column = getattr(
            self.version_class,
            option(self.model, 'transaction_column_name')
        )

        if not hasattr(self.version_class, 'transaction'):
            self.version_class.transaction = sa.orm.relationship(
                tx_class,
                primaryjoin=tx_class.id == transaction_column,
                foreign_keys=[transaction_column],
            )

    def base_classes(self):
        """
        Returns all base classes for history model.
        """
        return (version_base(self.manager, self.model), )

    def inheritance_args(self, cls, version_table, table):
        """
        Return mapper inheritance args for currently built history model.

        :param cls: the version class being mapped
        :param version_table: the version table of the model being built
        :param table: the parent model's own table (as passed by
            ``build_model``)
        :return: dict of mapper arguments, including whatever
            ``copy_mapper_args`` copies from the parent model
        """
        args = {}
        mapper = sa.inspect(self.model)

        if not mapper.single:
            parent = find_closest_versioned_parent(
                self.manager, self.model
            )
            if parent:
                # The version classes do not contain foreign keys, hence we
                # need to map inheritance condition manually for classes that
                # use joined table inheritance
                if parent.__table__.name != table.name:
                    inherit_condition = adapt_columns(
                        mapper.inherit_condition
                    )
                    # Resolve the column name per model, exactly like the
                    # other transaction column lookups in this class, so a
                    # model-level ``transaction_column_name`` override is
                    # honoured here too.
                    tx_column_name = option(
                        self.model, 'transaction_column_name'
                    )
                    args['inherit_condition'] = sa.and_(
                        inherit_condition,
                        getattr(parent.__table__.c, tx_column_name) ==
                        getattr(cls.__table__.c, tx_column_name)
                    )
                    args['inherit_foreign_keys'] = [
                        version_table.c[column.key]
                        for column in mapper.columns
                        if column.primary_key
                    ]

        args.update(copy_mapper_args(self.model))

        return args

    def get_inherited_denormalized_columns(self, table):
        """
        Return column properties that merge denormalized version columns.

        This applies when the model has versioned ancestors and uses neither
        single-table nor concrete inheritance. Each of these attribute names
        is mapped through one ``column_property``:

        - the operation type column;
        - the transaction column;
        - the end transaction column, when the ``validity`` strategy is used.

        The property spans the same-named column of ``table`` and of every
        versioned ancestor's table.

        :param table: the version table of the model being built
        :return: dict mapping attribute name to ``column_property``; empty
            when the conditions above do not hold
        """
        parent_models = list(versioned_parents(self.manager, self.model))
        mapper = sa.inspect(self.model)
        args = {}

        if parent_models and not (mapper.single or mapper.concrete):
            columns = [
                self.manager.option(self.model, 'operation_type_column_name'),
                self.manager.option(self.model, 'transaction_column_name')
            ]
            if self.manager.option(self.model, 'strategy') == 'validity':
                columns.append(
                    self.manager.option(
                        self.model,
                        'end_transaction_column_name'
                    )
                )

            for column in columns:
                args[column] = column_property(
                    table.c[column],
                    *[m.__table__.c[column] for m in parent_models]
                )
        return args

    def build_model(self, table):
        """
        Build history model class.

        The class name is ``<Module><Model>Version`` (module path title-cased
        with dots removed) unless the manager option ``use_module_name`` is
        false, in which case it is ``<Model>Version``.

        :param table: the version table for ``self.model``
        :return: the newly created version class
        """
        args = {}

        @declared_attr
        def mapper_args(cls):
            # inheritance_args already builds a fresh dict on every call.
            return self.inheritance_args(cls, table, self.model.__table__)

        args['__mapper_args__'] = mapper_args
        args['__versioning_manager__'] = self.manager
        args['__version_parent__'] = self.model

        # Only attach the table when it is not the one already mapped by the
        # closest versioned parent's version class.
        parent = find_closest_versioned_parent(self.manager, self.model)
        if not parent or parent.__table__.name != table.name:
            args['__table__'] = table

        args.update(self.get_inherited_denormalized_columns(table))

        if self.manager.options.get('use_module_name', True):
            name = '%s%sVersion' % (
                self.model.__module__.title().replace('.', ''),
                self.model.__name__
            )
        else:
            name = '%sVersion' % (self.model.__name__,)
        return type(name, self.base_classes(), args)

    def __call__(self, table, tx_class):
        """
        Build history model and relationships to parent model, transaction
        log model.

        :param table: the version table for ``self.model``
        :param tx_class: Transaction class
        :return: the built version class (also stored as
            ``self.version_class``)
        """
        # versioned attributes need to be copied for each child class,
        # otherwise each child class would share the same __versioned__
        # option dict
        self.model.__versioned__ = copy(self.model.__versioned__)
        self.model.__versioning_manager__ = self.manager
        self.version_class = self.build_model(table)
        self.build_parent_relationship()
        self.build_transaction_relationship(tx_class)
        return self.version_class