# -*- coding: utf-8; -*-
################################################################################
#
# Rattail -- Retail Software Framework
# Copyright © 2010-2023 Lance Edgar
#
# This file is part of Rattail.
#
# Rattail is free software: you can redistribute it and/or modify it under the
# terms of the GNU General Public License as published by the Free Software
# Foundation, either version 3 of the License, or (at your option) any later
# version.
#
# Rattail is distributed in the hope that it will be useful, but WITHOUT ANY
# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
# FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
# details.
#
# You should have received a copy of the GNU General Public License along with
# Rattail. If not, see <http://www.gnu.org/licenses/>.
#
################################################################################
"""
Cache Helpers
"""
import logging
from collections import OrderedDict
from sqlalchemy.orm import joinedload
from rattail.core import Object
from rattail.db import model
log = logging.getLogger(__name__)
class CacheKeyNotSupported(Exception):
    """
    Signal that no cache key can be produced for a record.

    A model cacher's key logic should raise this when the record/object it
    was handed lacks the data needed to build its key.
    """
class ModelCacher(object):
    """
    Generic model data caching class.

    Iterates over every record returned by a query and stores each one
    (optionally normalized) in an ordered dictionary under a configurable
    key.

    :param session: Database session; only used to build the default query
       when no explicit ``query`` is given.

    :param model_class: Class whose records are cached.  Its ``__name__``
       appears in progress and log messages.

    :param key: How to derive the cache key for a record.  May be an
       attribute name (string), a sequence of attribute names (the key is
       then a tuple of their values), or a callable invoked as
       ``key(instance, normalized)``.

    :param query: Optional pre-built query to use instead of
       ``session.query(model_class)``.

    :param order_by: Optional value handed to the query's ``order_by()``.

    :param query_options: Optional iterable; each element is handed to the
       query's ``options()``.

    :param normalizer: Optional callable applied to each record; its return
       value is what gets stored.  Defaults to storing the record itself.

    :param expect_duplicates: If true, the summary of duplicated keys is
       logged at debug level instead of warning level.

    :param omit_duplicates: If true, any key seen more than once is removed
       from the final cache; otherwise the last record seen wins.

    :param use_lists: If true, each key maps to a list of every record which
       produced it, and no duplicate tracking happens.

    :param message: Progress message; defaults to
       ``"Caching <model name> data"``.
    """

    def __init__(self, session, model_class, key='uuid',
                 query=None, order_by=None, query_options=None, normalizer=None,
                 expect_duplicates=False, omit_duplicates=False,
                 use_lists=False, message=None):
        self.session = session
        self.model_class = model_class
        self.key = key
        self.duplicate_keys = set()
        self.expect_duplicates = expect_duplicates
        self.omit_duplicates = omit_duplicates
        self._query = query
        self.order_by = order_by
        self.query_options = query_options
        if normalizer is None:
            self.normalize = lambda d: d
        else:
            self.normalize = normalizer
        self.use_lists = use_lists
        self.message = message
        if not self.message:
            self.message = "Caching {} data".format(self.model_name)

    @property
    def model_name(self):
        """
        Name of the model class, as used in messages.
        """
        return self.model_class.__name__

    def query(self):
        """
        Return the query whose results will be cached.

        Starts from the query given to the constructor, or a plain query on
        the model class, then applies ``order_by`` and ``query_options``.
        """
        q = self._query
        if q is None:
            q = self.session.query(self.model_class)

        # SQLAlchemy clause expressions (e.g. ``Model.col.desc()``) raise
        # TypeError when truth-tested, so only plain containers/strings are
        # checked for emptiness; anything else just needs to be non-None.
        order_by = self.order_by
        if isinstance(order_by, (str, list, tuple)):
            wants_order = bool(order_by)
        else:
            wants_order = order_by is not None
        if wants_order:
            q = q.order_by(order_by)

        for option in self.query_options or ():
            q = q.options(option)
        return q

    def get_cache(self, progress):
        """
        Build and return the cache.

        :param progress: Optional progress factory, called as
           ``progress(message, count)``.  The object it returns must offer
           ``update(i)`` and ``destroy()``.

        :returns: The ordered dictionary of cached records (also kept as
           ``self.instances``).
        """
        self.instances = OrderedDict()
        # Start each run with a clean slate; keys flagged by an earlier run
        # may not exist in this one, and deleting them below would fail.
        self.duplicate_keys = set()

        query = self.query()
        count = query.count()
        if not count:
            return self.instances

        prog = progress(self.message, count) if progress else None
        for i, instance in enumerate(query, 1):
            self.cache_instance(instance)
            if prog:
                prog.update(i)
        if prog:
            prog.destroy()

        if self.duplicate_keys:
            logger = log.debug if self.expect_duplicates else log.warning
            logger("found %s duplicated keys in cache for %s",
                   len(self.duplicate_keys), self.model_name)
            if self.omit_duplicates:
                for key in self.duplicate_keys:
                    log.debug("removing duplicated key from cache: %r", key)
                    self.instances.pop(key, None)

        return self.instances

    def get_key(self, instance, normalized):
        """
        Return the cache key for the given record.

        :param instance: The raw record.
        :param normalized: The normalized form of the record; only passed
           along when ``key`` is a callable.
        """
        if callable(self.key):
            return self.key(instance, normalized)
        if isinstance(self.key, str):
            return getattr(instance, self.key)
        return tuple(getattr(instance, k) for k in self.key)

    def cache_instance(self, instance):
        """
        Normalize the record and add it to the cache.

        Records whose key logic raises :class:`CacheKeyNotSupported` are
        silently skipped.
        """
        normalized = self.normalize(instance)
        try:
            key = self.get_key(instance, normalized)
        except CacheKeyNotSupported:
            # this means the object doesn't belong in our cache
            return

        if self.use_lists:
            self.instances.setdefault(key, []).append(normalized)
            return

        if key not in self.instances:
            self.instances[key] = normalized
            return

        # Key was already present: remember it, and unless duplicates are to
        # be dropped later, let the newest record win.
        self.duplicate_keys.add(key)
        if not self.omit_duplicates:
            log.debug("cache already contained key, but overwriting: %r", key)
            self.instances[key] = normalized
def cache_model(session, model_class, key='uuid', progress=None, **kwargs):
    """
    Build and return a data cache for the given model in a single call.

    Creates a :class:`ModelCacher` from ``session``, ``model_class``, ``key``
    and any extra keyword arguments, then returns the result of its
    ``get_cache()`` method, invoked with ``progress``.
    """
    return ModelCacher(session, model_class, key, **kwargs).get_cache(progress)
class DataCacher(Object):
    """
    Base class for simple per-model cachers.

    Subclasses are expected to provide ``class_`` (the model class to query)
    and implement :meth:`cache_instance()` to store each record in
    ``self.instances`` under whatever key they choose.
    """

    def __init__(self, session=None, **kwargs):
        # NOTE(review): presumably ``Object`` stores keyword arguments as
        # attributes (``self.session`` is read below) -- confirm.
        super(DataCacher, self).__init__(session=session, **kwargs)

    @property
    def class_(self):
        """
        Model class to be cached; must be overridden by subclasses.
        """
        raise NotImplementedError

    @property
    def name(self):
        """
        Name of the model class, as used in the progress message.
        """
        return self.class_.__name__

    def query(self):
        """
        Return the query whose results will be cached.
        """
        return self.session.query(self.class_)

    def cache_instance(self, instance):
        """
        Store a single record in ``self.instances``; subclasses must
        implement this.
        """
        raise NotImplementedError

    def get_cache(self, progress):
        """
        Build and return the cache.

        :param progress: Optional progress factory, called as
           ``progress(message, count)``.  The object it returns must offer
           ``update(i)`` and ``destroy()``; a false return value from
           ``update()`` cancels the operation.

        :returns: Dictionary of cached records, or ``None`` if the operation
           was canceled (in which case the session is also closed).
        """
        self.instances = {}
        query = self.query()
        count = query.count()
        if not count:
            return self.instances

        prog = None
        if progress:
            prog = progress("Caching {0} records".format(self.name), count)

        cancel = False
        for i, instance in enumerate(query, 1):
            self.cache_instance(instance)
            if prog and not prog.update(i):
                cancel = True
                break
        if prog:
            prog.destroy()

        if cancel:
            # the session lives on the cacher; there is no local ``session``
            self.session.close()
            return None
        return self.instances
class DepartmentCacher(DataCacher):
    """
    Cacher for department records, keyed by department number.
    """
    class_ = model.Department

    def cache_instance(self, dept):
        """
        Store the department under its ``number``.
        """
        number = dept.number
        self.instances[number] = dept
class SubdepartmentCacher(DataCacher):
    """
    Cacher for subdepartment records, keyed by subdepartment number.
    """
    class_ = model.Subdepartment

    def cache_instance(self, subdept):
        """
        Store the subdepartment under its ``number``.
        """
        number = subdept.number
        self.instances[number] = subdept
class CategoryCacher(DataCacher):
    """
    Cacher for category records, keyed by category code.
    """
    class_ = model.Category

    def cache_instance(self, category):
        """
        Store the category under its ``code``.
        """
        code = category.code
        self.instances[code] = category
def cache_categories(session, progress=None):
    """
    Return a dictionary of all :class:`rattail.db.model.Category` instances,
    keyed by :attr:`code` (the key which :class:`CategoryCacher` uses).
    """
    return CategoryCacher(session=session).get_cache(progress)
class FamilyCacher(DataCacher):
    """
    Cacher for family records, keyed by family code.
    """
    class_ = model.Family

    def cache_instance(self, family):
        """
        Store the family under its ``code``.
        """
        code = family.code
        self.instances[code] = family
class ReportCodeCacher(DataCacher):
    """
    Cacher for report code records, keyed by their code.
    """
    class_ = model.ReportCode

    def cache_instance(self, report_code):
        """
        Store the report code record under its ``code``.
        """
        code = report_code.code
        self.instances[code] = report_code
class BrandCacher(DataCacher):
    """
    Cacher for brand records, keyed by brand name.
    """
    class_ = model.Brand

    def cache_instance(self, brand):
        """
        Store the brand under its ``name``.
        """
        name = brand.name
        self.instances[name] = brand
class VendorCacher(DataCacher):
    """
    Cacher for vendor records, keyed by vendor ID.
    """
    class_ = model.Vendor

    def cache_instance(self, vend):
        """
        Store the vendor under its ``id``.
        """
        vendor_id = vend.id
        self.instances[vendor_id] = vend
class ProductCacher(DataCacher):
    """
    Cacher for product records, keyed by UPC.

    Set ``with_costs`` to have the query eagerly load each product's
    ``costs`` and ``cost`` relationships.
    """
    class_ = model.Product
    with_costs = False

    def query(self):
        """
        Return the product query, adding joined loads for cost data when
        ``with_costs`` is set.
        """
        products = self.session.query(model.Product)
        if not self.with_costs:
            return products
        for relation in (model.Product.costs, model.Product.cost):
            products = products.options(joinedload(relation))
        return products

    def cache_instance(self, prod):
        """
        Store the product under its ``upc``.
        """
        upc = prod.upc
        self.instances[upc] = prod
def get_product_cache(session, with_costs=False, progress=None):
    """
    Cache the full product set by UPC.

    Builds a :class:`ProductCacher` for ``session`` (passing ``with_costs``
    along) and returns the result of its ``get_cache()`` call: a dictionary
    of all existing products, keyed by :attr:`rattail.Product.upc`.
    """
    return ProductCacher(session=session,
                         with_costs=with_costs).get_cache(progress)
class CustomerGroupCacher(DataCacher):
    """
    Cacher for customer group records, keyed by group ID.
    """
    class_ = model.CustomerGroup

    def cache_instance(self, group):
        """
        Store the customer group under its ``id``.
        """
        group_id = group.id
        self.instances[group_id] = group
class CustomerCacher(DataCacher):
    """
    Cacher for customer records, keyed by customer ID.
    """
    class_ = model.Customer

    def cache_instance(self, customer):
        """
        Store the customer under its ``id``.
        """
        customer_id = customer.id
        self.instances[customer_id] = customer