From 0737bda552f0ac0066651920db7548f0b4cf1450 Mon Sep 17 00:00:00 2001 From: Allen Short Date: Wed, 23 Nov 2016 12:35:18 -0600 Subject: [PATCH] test_models passes --- redash/models.py | 297 +++++++++++++---------- tests/factories.py | 22 +- tests/tasks/test_refresh_queries.py | 2 +- tests/test_models.py | 355 +++++++++++++++------------- 4 files changed, 365 insertions(+), 311 deletions(-) diff --git a/redash/models.py b/redash/models.py index 1922d041e6..566a3b7097 100644 --- a/redash/models.py +++ b/redash/models.py @@ -108,17 +108,18 @@ def __setattr__(self, key, value): super(ChangeTrackingMixin, self).__setattr__(key, value) - def record_changes(self, session, changed_by): + def record_changes(self, changed_by): changes = {} for k, v in self._clean_values.iteritems(): if k not in self.skipped_fields: changes[k] = {'previous': v, 'current': getattr(self, k)} - session.add(Change(object_type=self.__class__.__tablename__, - object_id=self.id, + db.session.flush() + db.session.add(Change(object_type=self.__class__.__tablename__, + object=self, object_version=self.version, - user_id=changed_by.id, + user=changed_by, change=changes)) - session.add(self) + class ConflictDetectedError(Exception): @@ -127,7 +128,7 @@ class ConflictDetectedError(Exception): class BelongsToOrgMixin(object): @classmethod def get_by_id_and_org(cls, object_id, org): - return cls.query.filter(cls.id == object_id, cls.org == org).first() + return cls.query.filter(cls.id == object_id, cls.org == org).one_or_none() class PermissionsCheckMixin(object): @@ -180,6 +181,7 @@ class Organization(TimestampMixin, db.Model): name = Column(db.String(255)) slug = Column(db.String(255), unique=True) settings = Column(PseudoJSON) + groups = db.relationship("Group", lazy="dynamic") __tablename__ = 'organizations' @@ -220,7 +222,7 @@ class Group(db.Model, BelongsToOrgMixin): id = Column(db.Integer, primary_key=True) org_id = Column(db.Integer, db.ForeignKey('organizations.id')) - org = db.relationship(Organization, backref="groups") + org = db.relationship(Organization, back_populates="groups") type = Column(db.String(255), default=REGULAR_GROUP) name = Column(db.String(100)) permissions = Column(postgresql.ARRAY(db.String(255)), @@ -267,6 +269,7 @@ class User(TimestampMixin, db.Model, BelongsToOrgMixin, UserMixin, PermissionsCh name = Column(db.String(320)) email = Column(db.String(320)) password_hash = Column(db.String(128), nullable=True) + #XXX replace with association table group_ids = Column('groups', postgresql.ARRAY(db.Integer), nullable=True) api_key = Column(db.String(40), default=lambda: generate_token(40), @@ -338,8 +341,8 @@ def verify_password(self, password): def update_group_assignments(self, group_names): groups = Group.find_by_name(self.org, group_names) groups.append(self.org.default_group) - self.groups = map(lambda g: g.id, groups) - self.save() + self.group_ids = [g.id for g in groups] + db.session.add(self) def has_access(self, obj, access_type): return AccessPermission.exists(obj, access_type, grantee=self) @@ -368,6 +371,7 @@ class DataSource(BelongsToOrgMixin, db.Model): scheduled_queue_name = Column(db.String(255), default="scheduled_queries") created_at = Column(db.DateTime(True), default=db.func.now()) + data_source_groups = db.relationship("DataSourceGroup", back_populates="data_source") __tablename__ = 'data_sources' __table_args__ = (db.Index('data_sources_org_id_name', 'org_id', 'name'),) @@ -468,16 +472,19 @@ def all(cls, org, groups=None): return data_sources + #XXX examine call sites to see if a regular SQLA collection would work better @property def groups(self): - groups = DataSourceGroup.select().where(DataSourceGroup.data_source==self) + groups = db.session.query(DataSourceGroup).filter( + DataSourceGroup.data_source == self) return dict(map(lambda g: (g.group_id, g.view_only), groups)) class DataSourceGroup(db.Model): + #XXX drop id, use datasource/group as PK id = Column(db.Integer, primary_key=True) data_source_id = Column(db.Integer, db.ForeignKey("data_sources.id")) - data_source = db.relationship(DataSource) + data_source = db.relationship(DataSource, back_populates="data_source_groups") group_id = Column(db.Integer, db.ForeignKey("groups.id")) group = db.relationship(Group, backref="data_sources") view_only = Column(db.Boolean, default=False) @@ -514,8 +521,9 @@ def to_dict(self): def unused(cls, days=7): age_threshold = datetime.datetime.now() - datetime.timedelta(days=days) - unused_results = cls.select().where(Query.id == None, cls.retrieved_at < age_threshold)\ - .join(Query, join_type=peewee.JOIN_LEFT_OUTER) + unused_results = (db.session.query(QueryResult).filter( + Query.id == None, QueryResult.retrieved_at < age_threshold) + .outerjoin(Query)) return unused_results @@ -524,35 +532,41 @@ def get_latest(cls, data_source, query, max_age=0): query_hash = utils.gen_query_hash(query) if max_age == -1: - query = cls.select().where(cls.query_hash == query_hash, - cls.data_source == data_source).order_by(cls.retrieved_at.desc()) + q = db.session.query(QueryResult).filter( + cls.query_hash == query_hash, + cls.data_source == data_source).order_by( + QueryResult.retrieved_at.desc()) else: - query = cls.select().where(cls.query_hash == query_hash, cls.data_source == data_source, - peewee.SQL("retrieved_at at time zone 'utc' + interval '%s second' >= now() at time zone 'utc'", - max_age)).order_by(cls.retrieved_at.desc()) + q = db.session.query(QueryResult).filter( + QueryResult.query_hash == query_hash, + QueryResult.data_source == data_source, + db.func.timezone('utc', QueryResult.retrieved_at) + + datetime.timedelta(seconds=max_age) >= + db.func.timezone('utc', db.func.now()) + ).order_by(QueryResult.retrieved_at.desc()) - return query.first() + return q.first() @classmethod - def store_result(cls, org_id, data_source_id, query_hash, query, data, run_time, retrieved_at): - query_result = cls.create(org=org_id, - query_hash=query_hash, - query=query, - runtime=run_time, - data_source=data_source_id, - retrieved_at=retrieved_at, - data=data) - + def store_result(cls, org, data_source, query_hash, query, data, run_time, retrieved_at): + query_result = cls(org=org, + query_hash=query_hash, + query=query, + runtime=run_time, + data_source=data_source, + retrieved_at=retrieved_at, + data=data) + db.session.add(query_result) logging.info("Inserted query (%s) data; id=%s", query_hash, query_result.id) - sql = "UPDATE queries SET latest_query_data_id = %s WHERE query_hash = %s AND data_source_id = %s RETURNING id" - query_ids = [row[0] for row in db.database.execute_sql(sql, params=(query_result.id, query_hash, data_source_id))] - - # TODO: when peewee with update & returning support is released, we can get back to using this code: - # updated_count = Query.update(latest_query_data=query_result).\ - # where(Query.query_hash==query_hash, Query.data_source==data_source_id).\ - # execute() - + # TODO: Investigate how big an impact this select-before-update makes. + queries = db.session.query(Query).filter( + Query.query_hash == query_hash, + Query.data_source == data_source) + for q in queries: + q.latest_query_data = query_result + db.session.add(q) + query_ids = [q.id for q in queries] logging.info("Updated %s queries with result (%s).", len(query_ids), query_hash) return query_result, query_ids @@ -592,8 +606,6 @@ def generate_query_api_key(ctx): str(ctx.current_parameters['user_id']), ctx.current_parameters['name'])).encode('utf-8')).hexdigest() -def gen_query_hash(ctx): - return utils.gen_query_hash(ctx.current_parameters['query']) class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model): id = Column(db.Integer, primary_key=True) @@ -607,14 +619,11 @@ class Query(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model): name = Column(db.String(255)) description = Column(db.String(4096), nullable=True) query = Column(db.Text) - query_hash = Column(db.String(32), - default=gen_query_hash, - onupdate=gen_query_hash) + query_hash = Column(db.String(32)) api_key = Column(db.String(40), default=generate_query_api_key) user_id = Column(db.Integer, db.ForeignKey("users.id")) user = db.relationship(User, foreign_keys=[user_id]) - last_modified_by_id = Column(db.Integer, db.ForeignKey('users.id'), nullable=True, - onupdate=lambda ctx: ctx.current_parameters['user_id']) + last_modified_by_id = Column(db.Integer, db.ForeignKey('users.id'), nullable=True) last_modified_by = db.relationship(User, backref="modified_queries", foreign_keys=[last_modified_by_id]) is_archived = Column(db.Boolean, default=False, index=True) @@ -665,46 +674,48 @@ def to_dict(self, with_stats=False, with_visualizations=False, with_user=True, w return d def archive(self, user=None): + db.session.add(self) self.is_archived = True self.schedule = None for vis in self.visualizations: for w in vis.widgets: - w.delete_instance() + db.session.delete(w) - for alert in self.alerts: - alert.delete_instance(recursive=True) + for a in self.alerts: + db.session.delete(a) - self.save(changed_by=user) + if user: + self.record_changes(user) @classmethod def all_queries(cls, groups, drafts=False): - q = Query.select(Query, User, QueryResult.retrieved_at, QueryResult.runtime)\ - .join(QueryResult, join_type=peewee.JOIN_LEFT_OUTER)\ - .switch(Query).join(User)\ - .join(DataSourceGroup, on=(Query.data_source==DataSourceGroup.data_source))\ - .where(Query.is_archived==False)\ - .where(DataSourceGroup.group << groups)\ - .group_by(Query.id, User.id, QueryResult.id, QueryResult.retrieved_at, QueryResult.runtime)\ - .order_by(cls.created_at.desc()) + q = (db.session.query(Query) + .outerjoin(QueryResult) + .join(User, Query.user_id == User.id) + .join(DataSourceGroup, Query.data_source_id == DataSourceGroup.data_source_id) + .filter(Query.is_archived == False) + .filter(DataSourceGroup.group_id.in_([g.id for g in groups]))\ + .group_by(Query.id, User.id, QueryResult.id, QueryResult.retrieved_at, QueryResult.runtime) + .order_by(Query.created_at.desc())) if drafts: - q = q.where(Query.name == 'New Query') + q = q.filter(Query.name == 'New Query') else: - q = q.where(Query.name != 'New Query') + q = q.filter(Query.name != 'New Query') return q @classmethod def by_user(cls, user, drafts): - return cls.all_queries(user.groups, drafts).where(Query.user==user) + return cls.all_queries(user.groups, drafts).filter(Query.user == user) @classmethod def outdated_queries(cls): - queries = cls.select(cls, QueryResult.retrieved_at, DataSource)\ - .join(QueryResult)\ - .switch(Query).join(DataSource)\ - .where(cls.schedule != None) + queries = (db.session.query(Query) + .join(QueryResult) + .join(DataSource) + .filter(Query.schedule != None)) now = utils.utcnow() outdated_queries = {} @@ -718,38 +729,41 @@ def outdated_queries(cls): @classmethod def search(cls, term, groups): # TODO: This is very naive implementation of search, to be replaced with PostgreSQL full-text-search solution. - - where = (cls.name**u"%{}%".format(term)) | (cls.description**u"%{}%".format(term)) + where = (Query.name.like(u"%{}%".format(term)) | + Query.description.like(u"%{}%".format(term))) if term.isdigit(): - where |= cls.id == term - - where &= cls.is_archived == False + where |= Query.id == term - query_ids = cls.select(peewee.fn.Distinct(cls.id))\ - .join(DataSourceGroup, on=(Query.data_source==DataSourceGroup.data_source)) \ - .where(where) \ - .where(DataSourceGroup.group << groups) - - return cls.select(Query, User).join(User).where(cls.id << query_ids) + where &= Query.is_archived == False + where &= DataSourceGroup.group_id.in_([g.id for g in groups]) + query_ids = ( + db.session.query(Query.id).join( + DataSourceGroup, + Query.data_source_id == DataSourceGroup.data_source_id) + .filter(where)).distinct() + return db.session.query(Query).join(User, Query.user_id == User.id).filter( + Query.id.in_(query_ids)) @classmethod def recent(cls, groups, user_id=None, limit=20): - query = cls.select(Query, User).where(Event.created_at > peewee.SQL("current_date - 7")).\ - join(Event, on=(Query.id == Event.object_id.cast('integer'))). \ - join(DataSourceGroup, on=(Query.data_source==DataSourceGroup.data_source)). \ - switch(Query).join(User).\ - where(Event.action << ('edit', 'execute', 'edit_name', 'edit_description', 'view_source')).\ - where(~(Event.object_id >> None)).\ - where(Event.object_type == 'query'). \ - where(DataSourceGroup.group << groups).\ - where(cls.is_archived == False).\ - group_by(Event.object_id, Query.id, User.id).\ - order_by(peewee.SQL("count(0) desc")) + query = (db.session.query(Query).join(User, Query.user_id == User.id) + .filter(Event.created_at > (db.func.current_date() - 7)) + .join(Event, Query.id == Event.object_id.cast(db.Integer)) + .join(DataSourceGroup, Query.data_source_id == DataSourceGroup.data_source_id) + .filter( + Event.action.in_(['edit', 'execute', 'edit_name', + 'edit_description', 'view_source']), + Event.object_id != None, + Event.object_type == 'query', + DataSourceGroup.group_id.in_([g.id for g in groups]), + Query.is_archived == False) + .group_by(Event.object_id, Query.id, User.id) + .order_by(db.desc(db.func.count(0)))) if user_id: - query = query.where(Event.user == user_id) + query = query.filter(Event.user_id == user_id) query = query.limit(limit) @@ -773,14 +787,26 @@ def groups(self): def __unicode__(self): return unicode(self.id) +@listens_for(Query.query, 'set') +def gen_query_hash(target, val, oldval, initiator): + target.query_hash = utils.gen_query_hash(val) + +@listens_for(Query.user_id, 'set') +def query_last_modified_by(target, val, oldval, initiator): + target.last_modified_by_id = val + @listens_for(SignallingSession, 'before_flush') -def create_default_visualizations(session, ctx, *a): +def create_defaults(session, ctx, *a): for obj in session.new: if isinstance(obj, Query): session.add(Visualization(query=obj, name="Table", description='', type="TABLE", options="{}")) +@listens_for(ChangeTrackingMixin, 'init') +def create_first_change(obj, args, kwargs): + obj.record_changes(obj.user) + class AccessPermission(GFKBase, db.Model): @@ -891,6 +917,7 @@ class Alert(TimestampMixin, db.Model): user = db.relationship(User, backref='alerts') options = Column(PseudoJSON) state = Column(db.String(255), default=UNKNOWN_STATE) + subscriptions = db.relationship("AlertSubscription", cascade="delete") last_triggered_at = Column(db.DateTime(True), nullable=True) rearm = Column(db.Integer, nullable=True) @@ -974,6 +1001,7 @@ class Dashboard(ChangeTrackingMixin, TimestampMixin, BelongsToOrgMixin, db.Model name = Column(db.String(100)) user_id = Column(db.Integer, db.ForeignKey("users.id")) user = db.relationship(User) + # XXX replace with association table layout = Column(db.Text) dashboard_filters_enabled = Column(db.Boolean, default=False) is_archived = Column(db.Boolean, default=False, index=True) @@ -1037,42 +1065,47 @@ def to_dict(self, with_widgets=False, user=None): } @classmethod - def all(cls, org, groups, user_id): - query = cls.select().\ - join(Widget, peewee.JOIN_LEFT_OUTER, on=(Dashboard.id == Widget.dashboard)). \ - join(Visualization, peewee.JOIN_LEFT_OUTER, on=(Widget.visualization == Visualization.id)). \ - join(Query, peewee.JOIN_LEFT_OUTER, on=(Visualization.query == Query.id)). \ - join(DataSourceGroup, peewee.JOIN_LEFT_OUTER, on=(Query.data_source == DataSourceGroup.data_source)). \ - where(Dashboard.is_archived == False). \ - where((DataSourceGroup.group << groups) | - (Dashboard.user == user_id) | - (~(Widget.dashboard >> None) & (Widget.visualization >> None))). \ - where(Dashboard.org == org). \ - group_by(Dashboard.id) + def all(cls, org, group_ids, user_id): + query = ( + db.session.query(Dashboard) + .outerjoin(Widget) + .outerjoin(Visualization) + .outerjoin(Query) + .outerjoin(DataSourceGroup, Query.data_source_id == DataSourceGroup.data_source_id) + .filter( + Dashboard.is_archived == False, + (DataSourceGroup.group_id.in_(group_ids) | + (Dashboard.user_id == user_id) | + ((Widget.dashboard != None) & (Widget.visualization == None))), + Dashboard.org == org) + .group_by(Dashboard.id)) return query @classmethod - def recent(cls, org, groups, user_id, for_user=False, limit=20): - query = cls.select().where(Event.created_at > peewee.SQL("current_date - 7")). \ - join(Event, peewee.JOIN_LEFT_OUTER, on=(Dashboard.id == Event.object_id.cast('integer'))). \ - join(Widget, peewee.JOIN_LEFT_OUTER, on=(Dashboard.id == Widget.dashboard)). \ - join(Visualization, peewee.JOIN_LEFT_OUTER, on=(Widget.visualization == Visualization.id)). \ - join(Query, peewee.JOIN_LEFT_OUTER, on=(Visualization.query == Query.id)). \ - join(DataSourceGroup, peewee.JOIN_LEFT_OUTER, on=(Query.data_source == DataSourceGroup.data_source)). \ - where(Event.action << ('edit', 'view')). \ - where(~(Event.object_id >> None)). \ - where(Event.object_type == 'dashboard'). \ - where(Dashboard.is_archived == False). \ - where(Dashboard.org == org). \ - where((DataSourceGroup.group << groups) | - (Dashboard.user == user_id) | - (~(Widget.dashboard >> None) & (Widget.visualization >> None))). \ - group_by(Event.object_id, Dashboard.id). \ - order_by(peewee.SQL("count(0) desc")) + def recent(cls, org, group_ids, user_id, for_user=False, limit=20): + query = (db.session.query(Dashboard) + .outerjoin(Event, Dashboard.id == Event.object_id.cast(db.Integer)) + .outerjoin(Widget) + .outerjoin(Visualization) + .outerjoin(Query) + .outerjoin(DataSourceGroup, Query.data_source_id == DataSourceGroup.data_source_id) + .filter( + Event.created_at > (db.func.current_date() - 7), + Event.action.in_(['edit', 'view']), + Event.object_id != None, + Event.object_type == 'dashboard', + Dashboard.org == org, + Dashboard.is_archived == False, + DataSourceGroup.group_id.in_(group_ids) | + (Dashboard.user_id == user_id) | + ((Widget.dashboard != None) & (Widget.visualization == None))) + .group_by(Event.object_id, Dashboard.id) + .order_by(db.desc(db.func.count(0)))) + if for_user: - query = query.where(Event.user == user_id) + query = query.filter(Event.user_id == user_id) query = query.limit(limit) @@ -1138,7 +1171,7 @@ class Widget(TimestampMixin, db.Model): width = Column(db.Integer) options = Column(db.Text) dashboard_id = Column(db.Integer, db.ForeignKey("dashboards.id"), index=True) - dashboard = db.relationship(Dashboard, backref='widgets') + dashboard = db.relationship(Dashboard) # unused; kept for backward compatability: type = Column(db.String(100), nullable=True) @@ -1169,13 +1202,14 @@ def __unicode__(self): def get_by_id_and_org(cls, widget_id, org): return cls.select(cls, Dashboard).join(Dashboard).where(cls.id == widget_id, Dashboard.org == org).get() - def delete_instance(self, *args, **kwargs): - layout = json.loads(self.dashboard.layout) - layout = map(lambda row: filter(lambda w: w != self.id, row), layout) - layout = filter(lambda row: len(row) > 0, layout) - self.dashboard.layout = json.dumps(layout) - self.dashboard.save() - super(Widget, self).delete_instance(*args, **kwargs) +#XXX produces SQLA warning, replace with association table +@listens_for(Widget, 'before_delete') +def widget_delete(mapper, connection, self): + layout = json.loads(self.dashboard.layout) + layout = map(lambda row: filter(lambda w: w != self.id, row), layout) + layout = filter(lambda row: len(row) > 0, layout) + self.dashboard.layout = json.dumps(layout) + db.session.add(self.dashboard) class Event(db.Model): @@ -1185,6 +1219,7 @@ class Event(db.Model): user_id = Column(db.Integer, db.ForeignKey("users.id"), nullable=True) user = db.relationship(User, backref="events") action = Column(db.String(255)) + # XXX replace with association table object_type = Column(db.String(255)) object_id = Column(db.String(255), nullable=True) additional_properties = Column(db.Text, nullable=True) @@ -1197,8 +1232,8 @@ def __unicode__(self): @classmethod def record(cls, event): - org = event.pop('org_id') - user = event.pop('user_id', None) + org_id = event.pop('org_id') + user_id = event.pop('user_id', None) action = event.pop('action') object_type = event.pop('object_type') object_id = event.pop('object_id', None) @@ -1206,9 +1241,11 @@ def record(cls, event): created_at = datetime.datetime.utcfromtimestamp(event.pop('timestamp')) additional_properties = json.dumps(event) - event = cls.create(org=org, user=user, action=action, object_type=object_type, object_id=object_id, - additional_properties=additional_properties, created_at=created_at) - + event = cls(org_id=org_id, user_id=user_id, action=action, + object_type=object_type, object_id=object_id, + additional_properties=additional_properties, + created_at=created_at) + db.session.add(event) return event class ApiKey(TimestampMixin, GFKBase, db.Model): @@ -1296,7 +1333,7 @@ class AlertSubscription(TimestampMixin, db.Model): nullable=True) destination = db.relationship(NotificationDestination) alert_id = Column(db.Integer, db.ForeignKey("alerts.id")) - alert = db.relationship(Alert, backref="subscriptions") + alert = db.relationship(Alert, back_populates="subscriptions") __tablename__ = 'alert_subscriptions' __table_args__ = (db.Index('alert_subscriptions_destination_id_alert_id', diff --git a/tests/factories.py b/tests/factories.py index 448d03e597..aebe017a5e 100644 --- a/tests/factories.py +++ b/tests/factories.py @@ -69,7 +69,7 @@ def __call__(self): user=user_factory.create, schedule=None, data_source=data_source_factory.create, - org=1) + org_id=1) query_with_params_factory = ModelFactory(redash.models.Query, name='New Query with Params', @@ -79,7 +79,7 @@ def __call__(self): is_archived=False, schedule=None, data_source=data_source_factory.create, - org=1) + org_id=1) access_permission_factory = ModelFactory(redash.models.AccessPermission, object_id=query_factory.create, @@ -101,7 +101,7 @@ def __call__(self): query="SELECT 1", query_hash=gen_query_hash('SELECT 1'), data_source=data_source_factory.create, - org=1) + org_id=1) visualization_factory = ModelFactory(redash.models.Visualization, type='CHART', @@ -118,7 +118,7 @@ def __call__(self): visualization=visualization_factory.create) destination_factory = ModelFactory(redash.models.NotificationDestination, - org=1, + org_id=1, user=user_factory.create, name='Destination', type='slack', @@ -231,21 +231,23 @@ def create_alert_subscription(self, **kwargs): return alert_subscription_factory.create(**args) def create_data_source(self, **kwargs): + group = None + if 'group' in kwargs: + group = kwargs.pop('group') args = { 'org': self.org } args.update(kwargs) - if 'group' in kwargs and 'org' not in kwargs: - args['org'] = kwargs['group'].org + if group and 'org' not in kwargs: + args['org'] = group.org data_source = data_source_factory.create(**args) - if 'group_id' in kwargs: + if group: view_only = kwargs.pop('view_only', False) - db.session.add(redash.models.DataSourceGroup( - group_id=kwargs['group_id'], + group=group, data_source=data_source, view_only=view_only)) @@ -292,7 +294,7 @@ def create_query_result(self, **kwargs): args.update(kwargs) if 'data_source' in args and 'org' not in args: - args['org'] = args['data_source'].org_id + args['org'] = args['data_source'].org return query_result_factory.create(**args) diff --git a/tests/tasks/test_refresh_queries.py b/tests/tasks/test_refresh_queries.py index 2ada084d7c..8927f9c928 100644 --- a/tests/tasks/test_refresh_queries.py +++ b/tests/tasks/test_refresh_queries.py @@ -98,7 +98,7 @@ def test_enqueues_only_for_relevant_data_source(self): query = self.factory.create_query(schedule="60") query2 = self.factory.create_query(schedule="3600", query=query.query, query_hash=query.query_hash) import psycopg2 - retrieved_at = utcnow().replace(tzinfo=psycopg2.tz.FixedOffsetTimezone(offset=0, name=None)) - datetime.timedelta(minutes=10) + retrieved_at = utcnow() - datetime.timedelta(minutes=10) query_result = self.factory.create_query_result(retrieved_at=retrieved_at, query=query.query, query_hash=query.query_hash) query.latest_query_data = query_result diff --git a/tests/test_models.py b/tests/test_models.py index fdf56b4743..ea9a33848a 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -40,8 +40,7 @@ def test_search_finds_in_name(self): q1 = self.factory.create_query(name=u"Testing seåřċħ") q2 = self.factory.create_query(name=u"Testing seåřċħing") q3 = self.factory.create_query(name=u"Testing seå řċħ") - db.session.flush() - queries = models.Query.search(u"seåřċħ", [self.factory.default_group]) + queries = list(models.Query.search(u"seåřċħ", [self.factory.default_group])) self.assertIn(q1, queries) self.assertIn(q2, queries) @@ -62,7 +61,7 @@ def test_search_by_id_returns_query(self): q1 = self.factory.create_query(description="Testing search") q2 = self.factory.create_query(description="Testing searching") q3 = self.factory.create_query(description="Testing sea rch") - + db.session.flush() queries = models.Query.search(str(q3.id), [self.factory.default_group]) self.assertIn(q3, queries) @@ -70,25 +69,26 @@ def test_search_by_id_returns_query(self): self.assertNotIn(q2, queries) def test_search_respects_groups(self): - other_group = models.Group.create(org=self.factory.org, name="Other Group") + other_group = models.Group(org=self.factory.org, name="Other Group") + db.session.add(other_group) ds = self.factory.create_data_source(group=other_group) q1 = self.factory.create_query(description="Testing search", data_source=ds) q2 = self.factory.create_query(description="Testing searching") q3 = self.factory.create_query(description="Testing sea rch") - queries = models.Query.search("Testing", [self.factory.default_group]) + queries = list(models.Query.search("Testing", [self.factory.default_group])) self.assertNotIn(q1, queries) self.assertIn(q2, queries) self.assertIn(q3, queries) - queries = models.Query.search("Testing", [other_group, self.factory.default_group]) + queries = list(models.Query.search("Testing", [other_group, self.factory.default_group])) self.assertIn(q1, queries) self.assertIn(q2, queries) self.assertIn(q3, queries) - queries = models.Query.search("Testing", [other_group]) + queries = list(models.Query.search("Testing", [other_group])) self.assertIn(q1, queries) self.assertNotIn(q2, queries) self.assertNotIn(q3, queries) @@ -100,22 +100,23 @@ def test_returns_each_query_only_once(self): ds.add_group(second_group, False) q1 = self.factory.create_query(description="Testing search", data_source=ds) - + db.session.flush() queries = list(models.Query.search("Testing", [self.factory.default_group, other_group, second_group])) self.assertEqual(1, len(queries)) def test_save_creates_default_visualization(self): q = self.factory.create_query() - self.assertEquals(q.visualizations.count(), 1) + db.session.flush() + self.assertEquals(len(q.visualizations), 1) def test_save_updates_updated_at_field(self): # This should be a test of ModelTimestampsMixin, but it's easier to test in context of existing model... :-\ - one_day_ago = datetime.datetime.today() - datetime.timedelta(days=1) + one_day_ago = utcnow().date() - datetime.timedelta(days=1) q = self.factory.create_query(created_at=one_day_ago, updated_at=one_day_ago) - - q.save() - + db.session.flush() + q.name = 'x' + db.session.flush() self.assertNotEqual(q.updated_at, one_day_ago) @@ -123,22 +124,21 @@ class QueryRecentTest(BaseTestCase): def test_global_recent(self): q1 = self.factory.create_query() q2 = self.factory.create_query() - - models.Event.create(org=self.factory.org, user=self.factory.user, action="edit", - object_type="query", object_id=q1.id) - + db.session.flush() + e = models.Event(org=self.factory.org, user=self.factory.user, action="edit", + object_type="query", object_id=q1.id) + db.session.add(e) recent = models.Query.recent([self.factory.default_group]) - self.assertIn(q1, recent) self.assertNotIn(q2, recent) def test_recent_for_user(self): q1 = self.factory.create_query() q2 = self.factory.create_query() - - models.Event.create(org=self.factory.org, user=self.factory.user, action="edit", - object_type="query", object_id=q1.id) - + db.session.flush() + e = models.Event(org=self.factory.org, user=self.factory.user, action="edit", + object_type="query", object_id=q1.id) + db.session.add(e) recent = models.Query.recent([self.factory.default_group], user_id=self.factory.user.id) self.assertIn(q1, recent) @@ -152,11 +152,11 @@ def test_respects_groups(self): q1 = self.factory.create_query() ds = self.factory.create_data_source(group=self.factory.create_group()) q2 = self.factory.create_query(data_source=ds) - - models.Event.create(org=self.factory.org, user=self.factory.user, action="edit", - object_type="query", object_id=q1.id) - models.Event.create(org=self.factory.org, user=self.factory.user, action="edit", - object_type="query", object_id=q2.id) + db.session.flush() + models.Event(org=self.factory.org, user=self.factory.user, action="edit", + object_type="query", object_id=q1.id) + models.Event(org=self.factory.org, user=self.factory.user, action="edit", + object_type="query", object_id=q2.id) recent = models.Query.recent([self.factory.default_group]) @@ -166,17 +166,17 @@ def test_respects_groups(self): class ShouldScheduleNextTest(TestCase): def test_interval_schedule_that_needs_reschedule(self): - now = datetime.datetime.now() + now = utcnow() two_hours_ago = now - datetime.timedelta(hours=2) self.assertTrue(models.should_schedule_next(two_hours_ago, now, "3600")) def test_interval_schedule_that_doesnt_need_reschedule(self): - now = datetime.datetime.now() + now = utcnow() half_an_hour_ago = now - datetime.timedelta(minutes=30) self.assertFalse(models.should_schedule_next(half_an_hour_ago, now, "3600")) def test_exact_time_that_needs_reschedule(self): - now = datetime.datetime.now() + now = utcnow() yesterday = now - datetime.timedelta(days=1) scheduled_datetime = now - datetime.timedelta(hours=3) scheduled_time = "{:02d}:00".format(scheduled_datetime.hour) @@ -189,7 +189,7 @@ def test_exact_time_that_doesnt_need_reschedule(self): self.assertFalse(models.should_schedule_next(yesterday, now, schedule)) def test_exact_time_with_day_change(self): - now = datetime.datetime.now().replace(hour=0, minute=1) + now = utcnow().replace(hour=0, minute=1) previous = (now - datetime.timedelta(days=2)).replace(hour=23, minute=59) schedule = "23:59".format(now.hour + 3) self.assertTrue(models.should_schedule_next(previous, now, schedule)) @@ -204,21 +204,19 @@ def test_outdated_queries_skips_unscheduled_queries(self): self.assertNotIn(query, queries) def test_outdated_queries_works_with_ttl_based_schedule(self): - two_hours_ago = datetime.datetime.now() - datetime.timedelta(hours=2) + two_hours_ago = utcnow() - datetime.timedelta(hours=2) query = self.factory.create_query(schedule="3600") - query_result = self.factory.create_query_result(query=query, retrieved_at=two_hours_ago) + query_result = self.factory.create_query_result(query=query.query, retrieved_at=two_hours_ago) query.latest_query_data = query_result - query.save() queries = models.Query.outdated_queries() self.assertIn(query, queries) def test_skips_fresh_queries(self): - half_an_hour_ago = datetime.datetime.now() - datetime.timedelta(minutes=30) + half_an_hour_ago = utcnow() - datetime.timedelta(minutes=30) query = self.factory.create_query(schedule="3600") - query_result = self.factory.create_query_result(query=query, retrieved_at=half_an_hour_ago) + query_result = self.factory.create_query_result(query=query.query, retrieved_at=half_an_hour_ago) query.latest_query_data = query_result - query.save() queries = models.Query.outdated_queries() self.assertNotIn(query, queries) @@ -226,9 +224,8 @@ def test_skips_fresh_queries(self): def test_outdated_queries_works_with_specific_time_schedule(self): half_an_hour_ago = utcnow() - datetime.timedelta(minutes=30) query = self.factory.create_query(schedule=half_an_hour_ago.strftime('%H:%M')) - query_result = self.factory.create_query_result(query=query, retrieved_at=half_an_hour_ago - datetime.timedelta(days=1)) + query_result = self.factory.create_query_result(query=query.query, retrieved_at=half_an_hour_ago - datetime.timedelta(days=1)) query.latest_query_data = query_result - query.save() queries = models.Query.outdated_queries() self.assertIn(query, queries) @@ -240,53 +237,51 @@ def setUp(self): def test_archive_query_sets_flag(self): query = self.factory.create_query() + db.session.flush() query.archive() - query = models.Query.get_by_id(query.id) self.assertEquals(query.is_archived, True) def test_archived_query_doesnt_return_in_all(self): query = self.factory.create_query(schedule="1") - yesterday = datetime.datetime.now() - datetime.timedelta(days=1) - query_result, _ = models.QueryResult.store_result(query.org, query.data_source.id, query.query_hash, query.query, "1", - 123, yesterday) + yesterday = utcnow() - datetime.timedelta(days=1) + query_result, _ = models.QueryResult.store_result( + query.org, query.data_source, query.query_hash, query.query, + "1", 123, yesterday) query.latest_query_data = query_result - query.save() - - self.assertIn(query, list(models.Query.all_queries(query.groups.keys()))) + groups = list(models.Group.query.filter(models.Group.id.in_(query.groups))) + self.assertIn(query, list(models.Query.all_queries(groups))) self.assertIn(query, models.Query.outdated_queries()) - + db.session.flush() query.archive() - self.assertNotIn(query, list(models.Query.all_queries(query.groups.keys()))) + self.assertNotIn(query, list(models.Query.all_queries(groups))) self.assertNotIn(query, models.Query.outdated_queries()) def test_removes_associated_widgets_from_dashboards(self): widget = self.factory.create_widget() query = widget.visualization.query - + db.session.commit() query.archive() - - self.assertRaises(models.Widget.DoesNotExist, models.Widget.get_by_id, widget.id) + db.session.flush() + self.assertEqual(db.session.query(models.Widget).get(widget.id), None) def test_removes_scheduling(self): query = self.factory.create_query(schedule="1") query.archive() - query = models.Query.get_by_id(query.id) - self.assertEqual(None, query.schedule) def test_deletes_alerts(self): subscription = self.factory.create_alert_subscription() query = subscription.alert.query - + db.session.commit() query.archive() - - self.assertRaises(models.Alert.DoesNotExist, models.Alert.get_by_id, subscription.alert.id) - self.assertRaises(models.AlertSubscription.DoesNotExist, models.AlertSubscription.get_by_id, subscription.id) + db.session.flush() + self.assertEqual(db.session.query(models.Alert).get(subscription.alert.id), None) + self.assertEqual(db.session.query(models.AlertSubscription).get(subscription.id), None) class DataSourceTest(BaseTestCase): @@ -339,12 +334,6 @@ def test_get_latest_returns_when_found(self): self.assertEqual(qr, found_query_result) - def test_get_latest_works_with_data_source_id(self): - qr = self.factory.create_query_result() - found_query_result = models.QueryResult.get_latest(qr.data_source.id, qr.query, 60) - - self.assertEqual(qr, found_query_result) - def test_get_latest_doesnt_return_query_from_different_data_source(self): qr = self.factory.create_query_result() data_source = self.factory.create_data_source() @@ -353,7 +342,7 @@ def test_get_latest_doesnt_return_query_from_different_data_source(self): self.assertIsNone(found_query_result) def test_get_latest_doesnt_return_if_ttl_expired(self): - yesterday = datetime.datetime.now() - datetime.timedelta(days=1) + yesterday = utcnow() - datetime.timedelta(days=1) qr = self.factory.create_query_result(retrieved_at=yesterday) found_query_result = models.QueryResult.get_latest(qr.data_source, qr.query, max_age=60) @@ -361,7 +350,7 @@ def test_get_latest_doesnt_return_if_ttl_expired(self): self.assertIsNone(found_query_result) def test_get_latest_returns_if_ttl_not_expired(self): - yesterday = datetime.datetime.now() - datetime.timedelta(seconds=30) + yesterday = utcnow() - datetime.timedelta(seconds=30) qr = self.factory.create_query_result(retrieved_at=yesterday) found_query_result = models.QueryResult.get_latest(qr.data_source, qr.query, max_age=120) @@ -369,7 +358,7 @@ def test_get_latest_returns_if_ttl_not_expired(self): self.assertEqual(found_query_result, qr) def test_get_latest_returns_the_most_recent_result(self): - yesterday = datetime.datetime.now() - datetime.timedelta(seconds=30) + yesterday = utcnow() - datetime.timedelta(seconds=30) old_qr = self.factory.create_query_result(retrieved_at=yesterday) qr = self.factory.create_query_result() @@ -378,10 +367,10 @@ def test_get_latest_returns_the_most_recent_result(self): self.assertEqual(found_query_result.id, qr.id) def test_get_latest_returns_the_last_cached_result_for_negative_ttl(self): - yesterday = datetime.datetime.now() + datetime.timedelta(days=-100) + yesterday = utcnow() + datetime.timedelta(days=-100) very_old = self.factory.create_query_result(retrieved_at=yesterday) - yesterday = datetime.datetime.now() + datetime.timedelta(days=-1) + yesterday = utcnow() + datetime.timedelta(days=-1) qr = self.factory.create_query_result(retrieved_at=yesterday) found_query_result = models.QueryResult.get_latest(qr.data_source, qr.query, -1) @@ -390,7 +379,7 @@ def test_get_latest_returns_the_last_cached_result_for_negative_ttl(self): class TestUnusedQueryResults(BaseTestCase): def test_returns_only_unused_query_results(self): - two_weeks_ago = datetime.datetime.now() - datetime.timedelta(days=14) + two_weeks_ago = utcnow() - datetime.timedelta(days=14) qr = self.factory.create_query_result() query = self.factory.create_query(latest_query_data=qr) unused_qr = self.factory.create_query_result(retrieved_at=two_weeks_ago) @@ -399,7 +388,7 @@ def test_returns_only_unused_query_results(self): self.assertNotIn(qr, models.QueryResult.unused()) def test_returns_only_over_a_week_old_results(self): - two_weeks_ago = datetime.datetime.now() - datetime.timedelta(days=14) + two_weeks_ago = utcnow() - datetime.timedelta(days=14) unused_qr = self.factory.create_query_result(retrieved_at=two_weeks_ago) new_unused_qr = self.factory.create_query_result() @@ -412,15 +401,21 @@ def test_returns_only_queries_in_given_groups(self): ds1 = self.factory.create_data_source() ds2 = self.factory.create_data_source() - group1 = models.Group.create(name="g1", org=ds1.org) - group2 = models.Group.create(name="g2", org=ds1.org) - - models.DataSourceGroup.create(group=group1, data_source=ds1, permissions=['create', 'view']) - models.DataSourceGroup.create(group=group2, data_source=ds2, permissions=['create', 'view']) + group1 = models.Group(name="g1", org=ds1.org, permissions=['create', 'view']) + group2 = models.Group(name="g2", org=ds1.org, permissions=['create', 'view']) q1 = self.factory.create_query(data_source=ds1) q2 = self.factory.create_query(data_source=ds2) + db.session.add_all([ + ds1, ds2, + group1, group2, + q1, q2, + models.DataSourceGroup( + group=group1, data_source=ds1), + models.DataSourceGroup(group=group2, data_source=ds2) + ]) + db.session.flush() self.assertIn(q1, list(models.Query.all_queries([group1]))) self.assertNotIn(q2, list(models.Query.all_queries([group1]))) self.assertIn(q1, list(models.Query.all_queries([group1, group2]))) @@ -432,14 +427,14 @@ def test_default_group_always_added(self): user = self.factory.create_user() user.update_group_assignments(["g_unknown"]) - self.assertItemsEqual([user.org.default_group.id], user.groups) + self.assertItemsEqual([user.org.default_group.id], user.group_ids) def test_update_group_assignments(self): user = self.factory.user - new_group = models.Group.create(id='999', name="g1", org=user.org) + new_group = models.Group(id=999, name="g1", org=user.org) user.update_group_assignments(["g1"]) - self.assertItemsEqual([user.org.default_group.id, new_group.id], user.groups) + self.assertItemsEqual([user.org.default_group.id, new_group.id], user.group_ids) class TestGroup(BaseTestCase): @@ -447,9 +442,9 @@ def test_returns_groups_with_specified_names(self): org1 = self.factory.create_org() org2 = self.factory.create_org() - matching_group1 = models.Group.create(id='999', name="g1", org=org1) - matching_group2 = models.Group.create(id='888', name="g2", org=org1) - non_matching_group = models.Group.create(id='777', name="g1", org=org2) + matching_group1 = models.Group(id=999, name="g1", org=org1) + matching_group2 = models.Group(id=888, name="g2", org=org1) + non_matching_group = models.Group(id=777, name="g1", org=org2) groups = models.Group.find_by_name(org1, ["g1", "g2"]) self.assertIn(matching_group1, groups) @@ -459,7 +454,7 @@ def test_returns_groups_with_specified_names(self): def test_returns_no_groups(self): org1 = self.factory.create_org() - models.Group.create(id='999', name="g1", org=org1) + models.Group(id=999, name="g1", org=org1) self.assertEqual([], models.Group.find_by_name(org1, ["non-existing"])) @@ -474,9 +469,9 @@ def setUp(self): self.data = "data" def test_stores_the_result(self): - query_result, _ = models.QueryResult.store_result(self.data_source.org_id, self.data_source.id, self.query_hash, - self.query, - self.data, self.runtime, self.utcnow) + query_result, _ = models.QueryResult.store_result( + self.data_source.org, self.data_source, self.query_hash, + self.query, self.data, self.runtime, self.utcnow) self.assertEqual(query_result.data, self.data) self.assertEqual(query_result.runtime, self.runtime) @@ -490,39 +485,39 @@ def test_updates_existing_queries(self): query2 = self.factory.create_query(query=self.query) query3 = self.factory.create_query(query=self.query) - query_result, _ = models.QueryResult.store_result(self.data_source.org_id, self.data_source.id, self.query_hash, - self.query, self.data, - self.runtime, self.utcnow) + query_result, _ = models.QueryResult.store_result( + self.data_source.org, self.data_source, self.query_hash, + self.query, self.data, self.runtime, self.utcnow) - self.assertEqual(models.Query.get_by_id(query1.id)._data['latest_query_data'], query_result.id) - self.assertEqual(models.Query.get_by_id(query2.id)._data['latest_query_data'], query_result.id) - self.assertEqual(models.Query.get_by_id(query3.id)._data['latest_query_data'], query_result.id) + self.assertEqual(query1.latest_query_data, query_result) + self.assertEqual(query2.latest_query_data, query_result) + self.assertEqual(query3.latest_query_data, query_result) def test_doesnt_update_queries_with_different_hash(self): query1 = self.factory.create_query(query=self.query) query2 = self.factory.create_query(query=self.query) query3 = self.factory.create_query(query=self.query + "123") - query_result, _ = models.QueryResult.store_result(self.data_source.org_id, self.data_source.id, self.query_hash, - self.query, self.data, - self.runtime, self.utcnow) + query_result, _ = models.QueryResult.store_result( + self.data_source.org, self.data_source, self.query_hash, + self.query, self.data, self.runtime, self.utcnow) - self.assertEqual(models.Query.get_by_id(query1.id)._data['latest_query_data'], query_result.id) - self.assertEqual(models.Query.get_by_id(query2.id)._data['latest_query_data'], query_result.id) - self.assertNotEqual(models.Query.get_by_id(query3.id)._data['latest_query_data'], query_result.id) + self.assertEqual(query1.latest_query_data, query_result) + self.assertEqual(query2.latest_query_data, query_result) + self.assertNotEqual(query3.latest_query_data, query_result) def test_doesnt_update_queries_with_different_data_source(self): query1 = self.factory.create_query(query=self.query) query2 = self.factory.create_query(query=self.query) query3 = self.factory.create_query(query=self.query, data_source=self.factory.create_data_source()) - query_result, _ = models.QueryResult.store_result(self.data_source.org_id, self.data_source.id, self.query_hash, - self.query, self.data, - self.runtime, self.utcnow) + query_result, _ = models.QueryResult.store_result( + self.data_source.org, self.data_source, self.query_hash, + self.query, self.data, self.runtime, self.utcnow) - self.assertEqual(models.Query.get_by_id(query1.id)._data['latest_query_data'], query_result.id) - self.assertEqual(models.Query.get_by_id(query2.id)._data['latest_query_data'], query_result.id) - self.assertNotEqual(models.Query.get_by_id(query3.id)._data['latest_query_data'], query_result.id) + self.assertEqual(query1.latest_query_data, query_result) + self.assertEqual(query2.latest_query_data, query_result) + self.assertNotEqual(query3.latest_query_data, query_result) class TestEvents(BaseTestCase): @@ -530,6 +525,7 @@ def raw_event(self): timestamp = 1411778709.791 user = self.factory.user created_at = datetime.datetime.utcfromtimestamp(timestamp) + db.session.flush() raw_event = {"action": "view", "timestamp": timestamp, "object_type": "dashboard", @@ -543,7 +539,7 @@ def test_records_event(self): raw_event, user, created_at = self.raw_event() event = models.Event.record(raw_event) - + db.session.flush() self.assertEqual(event.user, user) self.assertEqual(event.action, "view") self.assertEqual(event.object_type, "dashboard") @@ -564,32 +560,36 @@ class TestWidgetDeleteInstance(BaseTestCase): def test_delete_removes_from_layout(self): widget = self.factory.create_widget() widget2 = self.factory.create_widget(dashboard=widget.dashboard) + db.session.flush() widget.dashboard.layout = json.dumps([[widget.id, widget2.id]]) - widget.dashboard.save() - widget.delete_instance() - + db.session.delete(widget) + db.session.flush() self.assertEquals(json.dumps([[widget2.id]]), widget.dashboard.layout) def test_delete_removes_empty_rows(self): widget = self.factory.create_widget() widget2 = self.factory.create_widget(dashboard=widget.dashboard) + db.session.flush() widget.dashboard.layout = json.dumps([[widget.id, widget2.id]]) - widget.dashboard.save() - widget.delete_instance() - widget2.delete_instance() - + db.session.flush() + db.session.delete(widget) + db.session.delete(widget2) + db.session.flush() self.assertEquals("[]", widget.dashboard.layout) def _set_up_dashboard_test(d): - d.g1 = d.factory.create_group(name='First') - d.g2 = d.factory.create_group(name='Second') + d.g1 = d.factory.create_group(name='First', permissions=['create', 'view']) + d.g2 = d.factory.create_group(name='Second', permissions=['create', 'view']) d.ds1 = d.factory.create_data_source() d.ds2 = d.factory.create_data_source() + db.session.flush() d.u1 = d.factory.create_user(group_ids=[d.g1.id]) d.u2 = d.factory.create_user(group_ids=[d.g2.id]) - models.DataSourceGroup.create(group=d.g1, data_source=d.ds1, permissions=['create', 'view']) - models.DataSourceGroup.create(group=d.g2, data_source=d.ds2, permissions=['create', 'view']) + db.session.add_all([ + models.DataSourceGroup(group=d.g1, data_source=d.ds1), + models.DataSourceGroup(group=d.g2, data_source=d.ds2) + ]) d.q1 = d.factory.create_query(data_source=d.ds1) d.q2 = d.factory.create_query(data_source=d.ds2) d.v1 = d.factory.create_visualization(query=d.q1) @@ -608,43 +608,49 @@ def setUp(self): def test_requires_group_or_user_id(self): d1 = self.factory.create_dashboard() - - self.assertNotIn(d1, models.Dashboard.all(d1.user.org, d1.user.groups, None)) - self.assertIn(d1, models.Dashboard.all(d1.user.org, [0], d1.user.id)) + self.assertNotIn(d1, list(models.Dashboard.all( + d1.user.org, d1.user.group_ids, None))) + l2 = list(models.Dashboard.all( + d1.user.org, [0], d1.user.id)) + self.assertIn(d1, l2) def test_returns_dashboards_based_on_groups(self): - self.assertIn(self.w1.dashboard, models.Dashboard.all(self.u1.org, self.u1.groups, None)) - self.assertIn(self.w2.dashboard, models.Dashboard.all(self.u2.org, self.u2.groups, None)) - self.assertNotIn(self.w1.dashboard, models.Dashboard.all(self.u2.org, self.u2.groups, None)) - self.assertNotIn(self.w2.dashboard, models.Dashboard.all(self.u1.org, self.u1.groups, None)) + self.assertIn(self.w1.dashboard, list(models.Dashboard.all( + self.u1.org, self.u1.group_ids, None))) + self.assertIn(self.w2.dashboard, list(models.Dashboard.all( + self.u2.org, self.u2.group_ids, None))) + self.assertNotIn(self.w1.dashboard, list(models.Dashboard.all( + self.u2.org, self.u2.group_ids, None))) + self.assertNotIn(self.w2.dashboard, list(models.Dashboard.all( + self.u1.org, self.u1.group_ids, None))) def test_returns_each_dashboard_once(self): - dashboards = list(models.Dashboard.all(self.u2.org, self.u2.groups, None)) + dashboards = list(models.Dashboard.all(self.u2.org, self.u2.group_ids, None)) self.assertEqual(len(dashboards), 2) def test_returns_dashboard_you_have_partial_access_to(self): - self.assertIn(self.w5.dashboard, models.Dashboard.all(self.u1.org, self.u1.groups, None)) + self.assertIn(self.w5.dashboard, models.Dashboard.all(self.u1.org, self.u1.group_ids, None)) def test_returns_dashboards_created_by_user(self): d1 = self.factory.create_dashboard(user=self.u1) - - self.assertIn(d1, models.Dashboard.all(self.u1.org, self.u1.groups, self.u1.id)) - self.assertIn(d1, models.Dashboard.all(self.u1.org, [0], self.u1.id)) - self.assertNotIn(d1, models.Dashboard.all(self.u2.org, self.u2.groups, self.u2.id)) + db.session.flush() + self.assertIn(d1, list(models.Dashboard.all(self.u1.org, self.u1.group_ids, self.u1.id))) + self.assertIn(d1, list(models.Dashboard.all(self.u1.org, [0], self.u1.id))) + self.assertNotIn(d1, list(models.Dashboard.all(self.u2.org, self.u2.group_ids, self.u2.id))) def test_returns_dashboards_with_text_widgets(self): w1 = self.factory.create_widget(visualization=None) - self.assertIn(w1.dashboard, models.Dashboard.all(self.u1.org, self.u1.groups, None)) - self.assertIn(w1.dashboard, models.Dashboard.all(self.u2.org, self.u2.groups, None)) + self.assertIn(w1.dashboard, models.Dashboard.all(self.u1.org, self.u1.group_ids, None)) + self.assertIn(w1.dashboard, models.Dashboard.all(self.u2.org, self.u2.group_ids, None)) def test_returns_dashboards_from_current_org_only(self): w1 = self.factory.create_widget(visualization=None) user = self.factory.create_user(org=self.factory.create_org()) - self.assertIn(w1.dashboard, models.Dashboard.all(self.u1.org, self.u1.groups, None)) - self.assertNotIn(w1.dashboard, models.Dashboard.all(user.org, user.groups, None)) + self.assertIn(w1.dashboard, models.Dashboard.all(self.u1.org, self.u1.group_ids, None)) + self.assertNotIn(w1.dashboard, models.Dashboard.all(user.org, user.group_ids, None)) class TestDashboardRecent(BaseTestCase): @@ -653,62 +659,71 @@ def setUp(self): _set_up_dashboard_test(self) def test_returns_recent_dashboards_basic(self): - models.Event.create(org=self.factory.org, user=self.u1, action="view", - object_type="dashboard", object_id=self.w1.dashboard.id) - - self.assertIn(self.w1.dashboard, models.Dashboard.recent(self.u1.org, self.u1.groups, None)) - self.assertNotIn(self.w2.dashboard, models.Dashboard.recent(self.u1.org, self.u1.groups, None)) - self.assertNotIn(self.w1.dashboard, models.Dashboard.recent(self.u1.org, self.u2.groups, None)) + db.session.flush() + db.session.add(models.Event(org=self.factory.org, user=self.u1, action="view", + object_type="dashboard", object_id=self.w1.dashboard.id)) + self.assertIn(self.w1.dashboard, models.Dashboard.recent(self.u1.org, self.u1.group_ids, None)) + self.assertNotIn(self.w2.dashboard, models.Dashboard.recent(self.u1.org, self.u1.group_ids, None)) + self.assertNotIn(self.w1.dashboard, models.Dashboard.recent(self.u1.org, self.u2.group_ids, None)) def test_returns_recent_dashboards_created_by_user(self): d1 = self.factory.create_dashboard(user=self.u1) - models.Event.create(org=self.factory.org, user=self.u1, action="view", - object_type="dashboard", object_id=d1.id) - + db.session.flush() + db.session.add(models.Event(org=self.factory.org, user=self.u1, action="view", + object_type="dashboard", object_id=d1.id)) self.assertIn(d1, models.Dashboard.recent(self.u1.org, [0], self.u1.id)) self.assertNotIn(self.w2.dashboard, models.Dashboard.recent(self.u1.org, [0], self.u1.id)) self.assertNotIn(d1, models.Dashboard.recent(self.u2.org, [0], self.u2.id)) def test_returns_recent_dashboards_with_no_visualizations(self): w1 = self.factory.create_widget(visualization=None) - models.Event.create(org=self.factory.org, user=self.u1, action="view", - object_type="dashboard", object_id=w1.dashboard.id) - + db.session.flush() + db.session.add(models.Event(org=self.factory.org, user=self.u1, action="view", + object_type="dashboard", object_id=w1.dashboard.id)) + db.session.flush() self.assertIn(w1.dashboard, models.Dashboard.recent(self.u1.org, [0], self.u1.id)) self.assertNotIn(self.w2.dashboard, models.Dashboard.recent(self.u1.org, [0], self.u1.id)) def test_restricts_dashboards_for_user(self): - models.Event.create(org=self.factory.org, user=self.u1, action="view", - object_type="dashboard", object_id=self.w1.dashboard.id) - models.Event.create(org=self.factory.org, user=self.u2, action="view", - object_type="dashboard", object_id=self.w2.dashboard.id) - models.Event.create(org=self.factory.org, user=self.u1, action="view", - object_type="dashboard", object_id=self.w5.dashboard.id) - models.Event.create(org=self.factory.org, user=self.u2, action="view", - object_type="dashboard", object_id=self.w5.dashboard.id) - - self.assertIn(self.w1.dashboard, models.Dashboard.recent(self.u1.org, self.u1.groups, self.u1.id, for_user=True)) - self.assertIn(self.w2.dashboard, models.Dashboard.recent(self.u2.org, self.u2.groups, self.u2.id, for_user=True)) - self.assertNotIn(self.w1.dashboard, models.Dashboard.recent(self.u2.org, self.u2.groups, self.u2.id, for_user=True)) - self.assertNotIn(self.w2.dashboard, models.Dashboard.recent(self.u1.org, self.u1.groups, self.u1.id, for_user=True)) - self.assertIn(self.w5.dashboard, models.Dashboard.recent(self.u1.org, self.u1.groups, self.u1.id, for_user=True)) - self.assertIn(self.w5.dashboard, models.Dashboard.recent(self.u2.org, self.u2.groups, self.u2.id, for_user=True)) + db.session.flush() + db.session.add_all([ + models.Event(org=self.factory.org, user=self.u1, action="view", + object_type="dashboard", object_id=self.w1.dashboard.id), + models.Event(org=self.factory.org, user=self.u2, action="view", + object_type="dashboard", object_id=self.w2.dashboard.id), + models.Event(org=self.factory.org, user=self.u1, action="view", + object_type="dashboard", object_id=self.w5.dashboard.id), + models.Event(org=self.factory.org, user=self.u2, action="view", + object_type="dashboard", object_id=self.w5.dashboard.id) + ]) + db.session.flush() + self.assertIn(self.w1.dashboard, models.Dashboard.recent(self.u1.org, self.u1.group_ids, self.u1.id, for_user=True)) + self.assertIn(self.w2.dashboard, models.Dashboard.recent(self.u2.org, self.u2.group_ids, self.u2.id, for_user=True)) + self.assertNotIn(self.w1.dashboard, models.Dashboard.recent(self.u2.org, self.u2.group_ids, self.u2.id, for_user=True)) + self.assertNotIn(self.w2.dashboard, models.Dashboard.recent(self.u1.org, self.u1.group_ids, self.u1.id, for_user=True)) + self.assertIn(self.w5.dashboard, models.Dashboard.recent(self.u1.org, self.u1.group_ids, self.u1.id, for_user=True)) + self.assertIn(self.w5.dashboard, models.Dashboard.recent(self.u2.org, self.u2.group_ids, self.u2.id, for_user=True)) def test_returns_each_dashboard_once(self): - models.Event.create(org=self.factory.org, user=self.u1, action="view", - object_type="dashboard", object_id=self.w1.dashboard.id) - models.Event.create(org=self.factory.org, user=self.u1, action="view", - object_type="dashboard", object_id=self.w1.dashboard.id) - - dashboards = list(models.Dashboard.recent(self.u1.org, self.u1.groups, None)) + db.session.flush() + db.session.add_all([ + models.Event(org=self.factory.org, user=self.u1, action="view", + object_type="dashboard", object_id=self.w1.dashboard.id), + models.Event(org=self.factory.org, user=self.u1, action="view", + object_type="dashboard", object_id=self.w1.dashboard.id) + ]) + db.session.flush() + dashboards = list(models.Dashboard.recent(self.u1.org, self.u1.group_ids, None)) self.assertEqual(len(dashboards), 1) def test_returns_dashboards_from_current_org_only(self): w1 = self.factory.create_widget(visualization=None) - models.Event.create(org=self.factory.org, user=self.u1, action="view", - object_type="dashboard", object_id=w1.dashboard.id) - + db.session.flush() + db.session.add(models.Event( + org=self.factory.org, user=self.u1, action="view", + object_type="dashboard", object_id=w1.dashboard.id)) + db.session.flush() user = self.factory.create_user(org=self.factory.create_org()) - self.assertIn(w1.dashboard, models.Dashboard.recent(self.u1.org, self.u1.groups, None)) - self.assertNotIn(w1.dashboard, models.Dashboard.recent(user.org, user.groups, None)) + self.assertIn(w1.dashboard, models.Dashboard.recent(self.u1.org, self.u1.group_ids, None)) + self.assertNotIn(w1.dashboard, models.Dashboard.recent(user.org, user.group_ids, None))