Skip to content

Commit

Permalink
add a couple of missed custom key types hooks (getredash#5014)
Browse files Browse the repository at this point in the history
  • Loading branch information
Omer Lachish authored and andrewdever committed Oct 5, 2020
1 parent 4089050 commit 9838a30
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion redash/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -882,7 +882,7 @@ class Favorite(TimestampMixin, db.Model):
object_id = Column(key_type("Favorite"))
object = generic_relationship(object_type, object_id)

user_id = Column(db.Integer, db.ForeignKey("users.id"))
user_id = Column(key_type("User"), db.ForeignKey("users.id"))
user = db.relationship(User, backref="favorites")

__tablename__ = "favorites"
Expand Down
10 changes: 5 additions & 5 deletions redash/models/changes.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from sqlalchemy.inspection import inspect
from sqlalchemy_utils.models import generic_repr

from .base import GFKBase, db, Column, primary_key
from .base import GFKBase, db, Column, primary_key, key_type
from .types import PseudoJSON


Expand All @@ -10,7 +10,7 @@ class Change(GFKBase, db.Model):
id = primary_key("Change")
# 'object' defined in GFKBase
object_version = Column(db.Integer, default=0)
user_id = Column(db.Integer, db.ForeignKey("users.id"))
user_id = Column(key_type("User"), db.ForeignKey("users.id"))
user = db.relationship("User", backref="changes")
change = Column(PseudoJSON)
created_at = Column(db.DateTime(True), default=db.func.now())
Expand Down Expand Up @@ -57,15 +57,15 @@ def __init__(self, *a, **kw):
def prep_cleanvalues(self):
self.__dict__["_clean_values"] = {}
for attr in inspect(self.__class__).column_attrs:
col, = attr.columns
(col,) = attr.columns
# 'query' is col name but not attr name
self._clean_values[col.name] = None

def __setattr__(self, key, value):
if self._clean_values is None:
self.prep_cleanvalues()
for attr in inspect(self.__class__).column_attrs:
col, = attr.columns
(col,) = attr.columns
previous = getattr(self, attr.key, None)
self._clean_values[col.name] = previous

Expand All @@ -76,7 +76,7 @@ def record_changes(self, changed_by):
db.session.flush()
changes = {}
for attr in inspect(self.__class__).column_attrs:
col, = attr.columns
(col,) = attr.columns
if attr.key not in self.skipped_fields:
changes[col.name] = {
"previous": self._clean_values[col.name],
Expand Down

0 comments on commit 9838a30

Please sign in to comment.