Skip to content

Commit

Permalink
Add class wrapper for lazily converting attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
pprkut committed Mar 31, 2019
1 parent 439d4c1 commit f9f2bdd
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 34 deletions.
133 changes: 103 additions & 30 deletions beets/dbcore/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,100 @@ def _get_formatted(self, model, key):
return value


class LazyConvertDict(object):
"""Lazily convert types for attributes fetched from the database
"""

def __init__(self, model_cls):
"""Initialize the object empty
"""
self.data = {}
self.model_cls = model_cls
self._converted = {}

def init(self, data):
"""Set the base data that should be lazily converted
"""
self.data = data

def _convert(self, key, value):
"""Convert the attribute type according the the SQL type
"""
return self.model_cls._type(key).from_sql(value)

def __setitem__(self, key, value):
"""Set an attribute value, assume it's already converted
"""
self._converted[key] = value

def __getitem__(self, key):
"""Get an attribute value, converting the type on demand
if needed
"""
if key in self._converted:
return self._converted[key]
elif key in self.data:
value = self._convert(key, self.data[key])
self._converted[key] = value
return value

def __delitem__(self, key):
"""Delete both converted and base data
"""
if key in self._converted:
del self._converted[key]
if key in self.data:
del self.data[key]

def keys(self):
"""Get a list of available field names for this object.
"""
return list(self._converted.keys()) + list(self.data.keys())

def copy(self):
"""Create a copy of the object.
"""
new = self.__class__(self.model_cls)
new.data = self.data.copy()
new._converted = self._converted.copy()
return new

# Act like a dictionary.

def update(self, values):
"""Assign all values in the given dict.
"""
for key, value in values.items():
self[key] = value

def items(self):
"""Iterate over (key, value) pairs that this object contains.
Computed fields are not included.
"""
for key in self:
yield key, self[key]

def get(self, key, default=None):
"""Get the value for a given key or `default` if it does not
exist.
"""
if key in self:
return self[key]
else:
return default

def __contains__(self, key):
"""Determine whether `key` is an attribute on this object.
"""
return key in self.keys()

def __iter__(self):
"""Iterate over the available field names (excluding computed
fields).
"""
return iter(self.keys())


# Abstract base for model classes.

class Model(object):
Expand Down Expand Up @@ -177,10 +271,8 @@ def __init__(self, db=None, **values):
"""
self._db = db
self._dirty = set()
self._raw_values_fixed = {}
self._raw_values_flex = {}
self._values_fixed = {}
self._values_flex = {}
self._values_fixed = LazyConvertDict(self)
self._values_flex = LazyConvertDict(self)

# Initial contents.
self.update(values)
Expand All @@ -194,10 +286,10 @@ def _awaken(cls, db=None, fixed_values={}, flex_values={}):
ordinary construction are bypassed.
"""
obj = cls(db)
for key, value in fixed_values.items():
obj._raw_values_fixed[key] = value
for key, value in flex_values.items():
obj._raw_values_flex[key] = value

obj._values_fixed.init(fixed_values)
obj._values_flex.init(flex_values)

return obj

def __repr__(self):
Expand Down Expand Up @@ -234,9 +326,7 @@ def copy(self):
"""
new = self.__class__()
new._db = self._db
new._raw_values_fixed = self._raw_values_fixed.copy()
new._values_fixed = self._values_fixed.copy()
new._raw_values_flex = self._raw_values_flex.copy()
new._values_flex = self._values_flex.copy()
new._dirty = self._dirty.copy()
return new
Expand All @@ -262,16 +352,10 @@ def __getitem__(self, key):
elif key in self._fields: # Fixed.
if key in self._values_fixed:
return self._values_fixed[key]
elif key in self._raw_values_fixed:
self._values_fixed[key] = self._type(key).from_sql(self._raw_values_fixed[key])
return self._values_fixed[key]
else:
return self._type(key).null
elif key in self._values_flex: # Flexible.
return self._values_flex[key]
elif key in self._raw_values_flex: # Flexible.
self._values_flex[key] = self._type(key).from_sql(self._raw_values_flex[key])
return self._values_flex[key]
else:
raise KeyError(key)

Expand All @@ -281,12 +365,8 @@ def _setitem(self, key, value):
"""
# Choose where to place the value.
if key in self._fields:
if not key in self._values_fixed and key in self._raw_values_fixed:
self._values_fixed[key] = self._type(key).from_sql(self._raw_values_fixed[key])
source = self._values_fixed
else:
if not key in self._values_flex and key in self._raw_values_flex:
self._values_flex[key] = self._type(key).from_sql(self._raw_values_flex[key])
source = self._values_flex

# If the field has a type, filter the value.
Expand All @@ -311,11 +391,6 @@ def __delitem__(self, key):
"""
if key in self._values_flex: # Flexible.
del self._values_flex[key]
if key in self._raw_values_flex:
del self._raw_values_flex[key]
self._dirty.add(key) # Mark for dropping on store.
elif key in self._raw_values_flex: # Flexible
del self._raw_values_flex[key]
self._dirty.add(key) # Mark for dropping on store.
elif key in self._fields: # Fixed
setattr(self, key, self._type(key).null)
Expand All @@ -329,7 +404,7 @@ def keys(self, computed=False):
`computed` parameter controls whether computed (plugin-provided)
fields are included in the key list.
"""
base_keys = list(self._fields) + list(self._values_flex.keys()) + list(self._raw_values_flex.keys())
base_keys = list(self._fields) + list(self._values_flex.keys())
if computed:
return base_keys + list(self._getters().keys())
else:
Expand Down Expand Up @@ -458,10 +533,8 @@ def load(self):
self._check_db()
stored_obj = self._db._get(type(self), self.id)
assert stored_obj is not None, u"object {0} not in DB".format(self.id)
self._raw_values_fixed = {}
self._values_fixed = {}
self._raw_values_flex = {}
self._values_flex = {}
self._values_fixed = LazyConvertDict(self)
self._values_flex = LazyConvertDict(self)
self.update(dict(stored_obj))
self.clear_dirty()

Expand Down
8 changes: 4 additions & 4 deletions beets/importer.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,14 +785,14 @@ def reimport_metadata(self, lib):
replaced_album = self.replaced_albums.get(self.album.path)
if replaced_album:
self.album.added = replaced_album.added
self.album.update(replaced_album._raw_values_flex)
self.album.update(replaced_album._values_flex)
self.album.artpath = replaced_album.artpath
self.album.store()
log.debug(
u'Reimported album: added {0}, flexible '
u'attributes {1} from album {2} for {3}',
self.album.added,
replaced_album._raw_values_flex.keys(),
replaced_album._values_flex.keys(),
replaced_album.id,
displayable_path(self.album.path)
)
Expand All @@ -809,11 +809,11 @@ def reimport_metadata(self, lib):
dup_item.id,
displayable_path(item.path)
)
item.update(dup_item._raw_values_flex)
item.update(dup_item._values_flex)
log.debug(
u'Reimported item flexible attributes {0} '
u'from item {1} for {2}',
dup_item._raw_values_flex.keys(),
dup_item._values_flex.keys(),
dup_item.id,
displayable_path(item.path)
)
Expand Down

0 comments on commit f9f2bdd

Please sign in to comment.