# -*- coding: utf-8; -*-
################################################################################
#
# Rattail -- Retail Software Framework
# Copyright © 2010-2023 Lance Edgar
#
# This file is part of Rattail.
#
# Rattail is free software: you can redistribute it and/or modify it under the
# terms of the GNU General Public License as published by the Free Software
# Foundation, either version 3 of the License, or (at your option) any later
# version.
#
# Rattail is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License along with
# Rattail. If not, see <http://www.gnu.org/licenses/>.
#
################################################################################
"""
Database Utilities
"""
import re
import pprint
import logging
import warnings
import sqlalchemy as sa
from sqlalchemy import orm
from sqlalchemy.ext.associationproxy import ASSOCIATION_PROXY
# TODO: Deprecate/remove these imports.
from rattail.db.config import get_engines, get_default_engine, configure_session
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
[docs]
class CounterMagic(object):
    """
    Provides magic counter values, to simulate PostgreSQL sequence.

    Each counter ``key`` is backed by a tiny table named
    ``counter_<key>`` with a single integer primary key column.
    Requesting a value inserts a new row into that table and returns
    the id which the database generated for it.

    :param config: Stored as ``self.config``; not used by
       :meth:`next_value`.
    """

    def __init__(self, config):
        self.config = config
        # private metadata collection, holding only the counter tables
        # defined by this instance
        self.metadata = sa.MetaData()

    def next_value(self, session, key):
        """
        Increment and return the next counter value for given key.

        May be called any number of times for the same key on the same
        instance; the table definition is registered once in
        ``self.metadata`` and reused afterwards.

        :param session: Session whose ``bind`` is the engine to use.
        :param key: Counter name; the backing table is ``counter_<key>``.
        :returns: ``lastrowid`` of the newly inserted row (support for
           this depends on the DBAPI driver).
        """
        engine = session.bind
        name = 'counter_{}'.format(key)

        # Re-declaring a Table with columns on a MetaData which already
        # holds it raises InvalidRequestError, so reuse any existing one.
        table = self.metadata.tables.get(name)
        if table is None:
            table = sa.Table(name, self.metadata,
                             sa.Column('value', sa.Integer(), primary_key=True))

        # create the table on first use; checkfirst skips it when present
        table.create(engine, checkfirst=True)

        with engine.begin() as cxn:
            result = cxn.execute(table.insert())
            return result.lastrowid
[docs]
class QuerySequence(object):
    """
    Thin adapter which gives a query object just enough of the sequence
    protocol (``len()`` and iteration) to keep consumers such as
    importers happy.

    The wrapped object must be iterable, and must either support
    ``len()`` directly or provide a ``count()`` method.
    """

    def __init__(self, query):
        self.query = query

    def __len__(self):
        query = self.query
        try:
            size = len(query)
        except TypeError:
            # object has no native length (as with a SQLAlchemy Query),
            # so ask it to count itself instead
            size = query.count()
        return size

    def __iter__(self):
        source = self.query
        return iter(source)
[docs]
def short_session(
        session=None,
        Session=None,
        commit=False,
        factory=None,
        config=None):
    """
    Compatibility wrapper around
    :class:`wuttjamaican:wuttjamaican.db.sess.short_session`.

    The upstream version is a proper context manager (class), whereas
    this is a plain function; each call returns a new instance of the
    upstream class.

    Always pass keyword arguments here, since the argument order of
    this function differs from that of the upstream class.  This
    function is deprecated (calling it emits a
    :class:`DeprecationWarning`); new code should use the upstream
    class directly.

    :param session: Existing session, passed through to upstream.
    :param Session: Deprecated alias for ``factory``; only used when
       ``factory`` is not given.
    :param commit: Passed through to upstream.
    :param factory: Session factory, passed through to upstream.
    :param config: Config object, passed through to upstream.
    :returns: A new upstream ``short_session`` instance.
    """
    from wuttjamaican.db import short_session as upstream_short_session

    warnings.warn("rattail.db.util.short_session() is deprecated; "
                  "please use wuttjamaican.db.short_session() instead",
                  DeprecationWarning, stacklevel=2)

    # honor the legacy ``Session`` kwarg, but only as a fallback
    if not factory and Session:
        warnings.warn("passing a 'Session' kwarg is deprecated; "
                      "please pass 'factory' instead",
                      DeprecationWarning, stacklevel=2)
        factory = Session

    # with nothing else to go on, use rattail's global session factory
    if not (session or factory or config):
        from rattail.db import Session as default_factory
        factory = default_factory

    return upstream_short_session(session=session, factory=factory,
                                  config=config, commit=commit)
[docs]
def get_fieldnames(config, obj, columns=True, proxies=True,
                   relations=False):
    """
    Produce a simple list of field names for the given model class.

    Mapper properties (columns and/or relationships) come first, in
    mapper order, followed by any association proxies.

    :param config: Accepted for API compatibility; not used here.
    :param obj: Either a class or instance of a class, which derives
       from the base model class.
    :param columns: Whether or not to include simple columns.
    :param proxies: Whether or not to include association proxy fields.
    :param relations: Whether or not to include fields which represent
       relationships to other models, including association proxies
       whose target attribute is a relationship.  If ``False`` (the
       default) then only "simple" fields will be included.
    :returns: List of attribute names.
    """
    cls = obj if isinstance(obj, type) else obj.__class__
    mapper = orm.class_mapper(cls)

    # property types wanted, per the caller's flags
    wanted = ()
    if columns:
        wanted += (orm.ColumnProperty,)
    if relations:
        wanted += (orm.RelationshipProperty,)

    names = []
    if wanted:
        for prop in mapper.iterate_properties:
            if not isinstance(prop, wanted):
                continue
            # private attributes and ``versions`` are always skipped
            # (the latter presumably a versioning relationship - confirm)
            if prop.key.startswith('_') or prop.key == 'versions':
                continue
            names.append(prop.key)

    if proxies:
        descriptors = sa.inspect(cls).all_orm_descriptors
        for key, desc in descriptors.items():
            if desc.extension_type != ASSOCIATION_PROXY:
                continue
            if not relations:
                # skip proxies whose target attribute is itself a
                # relationship, since relationships were not requested
                target_cls = desc.for_class(cls).target_class
                target_prop = sa.inspect(target_cls).get_property(desc.value_attr)
                if isinstance(target_prop, orm.RelationshipProperty):
                    continue
            names.append(key)

    return names
[docs]
def maxlen(attr):
    """
    Return the maximum length for the given attribute.

    This is the ``length`` of the type of the attribute's single
    column.  ``None`` is returned when the attribute does not map to
    exactly one column, or when the column type has no ``length``.
    """
    cols = attr.property.columns
    if len(cols) != 1:
        return None
    return getattr(cols[0].type, 'length', None)
[docs]
def make_topo_sortkey(model, metadata=None):
    """
    Returns a function suitable for use as a ``key`` kwarg to a standard
    Python sorting call.

    The key function expects a model class *name*.  It returns a tuple
    holding, for each table mapped by that class, the table's 1-based
    position in SQLAlchemy's topological table sort
    (``metadata.sorted_tables``).  Names not found on ``model`` yield
    an empty tuple.

    :param model: Namespace (module) holding the model classes.
    :param metadata: Metadata to sort by; defaults to
       ``model.Base.metadata``.
    """
    if metadata is None:
        metadata = model.Base.metadata

    positions = {table.name: position
                 for position, table in enumerate(metadata.sorted_tables, 1)}

    def sortkey(name):
        if not hasattr(model, name):
            return ()
        mapper = orm.class_mapper(getattr(model, name))
        return tuple(positions[table.name] for table in mapper.tables)

    return sortkey
[docs]
def make_full_description(brand_name, description, size):
    """
    Combine the given field values into a complete description.

    Each value is stripped of surrounding whitespace (``None`` counts as
    empty); the non-empty ones are joined, in order, by single spaces.
    """
    stripped = [(value or '').strip()
                for value in (brand_name, description, size)]
    return ' '.join(part for part in stripped if part)
[docs]
def normalize_full_name(first_name, last_name):
    """
    Normalize the given first and last name to a "full" name value.

    Both parts are stripped of surrounding whitespace (``None`` counts as
    empty).  Non-empty parts are joined with a single space, so the
    result may be one name alone, or an empty string as a fallback.
    """
    parts = ((first_name or '').strip(), (last_name or '').strip())
    return ' '.join(part for part in parts if part)
##############################
# phone number validation
##############################
[docs]
class PhoneValidator(object):
    """
    Simple validator, used to ensure a phone number matches the general
    expected pattern: an optional ``1-`` prefix, then 3 + 3 + 4 digits
    optionally separated by dash, space or dot, then an optional
    extension (``ext``, case-insensitive).

    :param error: Stored as ``self.error``; note that :meth:`validate`
       does not currently consult it.
    """

    # NOTE: this pattern was taken from FormEncode
    _phoneRE = re.compile(r'^\s*(?:1-)?(\d\d\d)[\- \.]?(\d\d\d)[\- \.]?(\d\d\d\d)(?:\s*ext\.?\s*(\d+))?\s*$', re.I)

    def __init__(self, error=False):
        self.error = error

    def validate(self, number):
        """
        Check ``number`` against the expected pattern.

        :returns: ``number`` unchanged, if it is valid or empty/falsy.
        :raises ValueError: If ``number`` is non-empty and does not match.
        """
        if number and not self._phoneRE.search(number):
            raise ValueError("Phone number is not valid")
        return number
[docs]
def validate_phone_number(number, error=False):
    """
    Validate a single phone number, using :class:`PhoneValidator`.

    :param number: Phone number to check.
    :param error: Passed to the :class:`PhoneValidator` constructor.
    :returns: ``number`` unchanged, if valid or empty.
    :raises ValueError: If ``number`` is non-empty and invalid.
    """
    return PhoneValidator(error=error).validate(number)
[docs]
def normalize_phone_number(number):
    """
    Normalize a phone number to (ASCII) digits only.

    All non-digit characters are removed.  Any Unicode decimal digit
    (e.g. fullwidth or Arabic-Indic digits) is kept but converted to its
    ASCII equivalent, so the same number entered with different digit
    forms normalizes to the same string.

    :param number: Phone number string, or ``None``.
    :returns: String of ASCII digits ``0-9``, or ``None`` if ``number``
       is ``None``.
    """
    if number is None:
        return None

    import unicodedata

    # for str patterns ``\D`` is Unicode-aware, so this keeps every
    # Unicode decimal digit (category Nd), not only 0-9 ...
    digits = re.sub(r'\D', '', number)
    # ... hence fold each of those digits to its ASCII form
    return ''.join(str(unicodedata.decimal(digit)) for digit in digits)