Skip to content

Commit

Permalink
fix(reports): make name unique between alerts and reports (#12196)
Browse files Browse the repository at this point in the history
* fix(reports): make name unique between alerts and reports

* add missing migration

* make it work for MySQL and PG only (for now)

* fix SQLite unique constraint drop

* fix SQLite migration missing a column
  • Loading branch information
dpgaspar authored Dec 24, 2020
1 parent b75a1ec commit 74f3faf
Show file tree
Hide file tree
Showing 8 changed files with 176 additions and 14 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""alert reports shared uniqueness
Revision ID: c878781977c6
Revises: 73fd22e742ab
Create Date: 2020-12-23 11:34:53.882200
"""

# revision identifiers, used by Alembic.
revision = "c878781977c6"
down_revision = "73fd22e742ab"

import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects.mysql.base import MySQLDialect
from sqlalchemy.dialects.postgresql.base import PGDialect
from sqlalchemy.dialects.sqlite.base import SQLiteDialect
from sqlalchemy.engine.reflection import Inspector

from superset.utils.core import generic_find_uq_constraint_name

# Snapshot of the current ``report_schedule`` table schema.
# Required by SQLite's batch mode in ``upgrade()`` (``copy_from=report_schedule``),
# which recreates the table from this definition instead of altering it in place.
report_schedule = sa.Table(
    "report_schedule",
    sa.MetaData(),
    sa.Column("id", sa.Integer(), nullable=False),
    sa.Column("type", sa.String(length=50), nullable=False),
    sa.Column("name", sa.String(length=150), nullable=False),
    sa.Column("description", sa.Text(), nullable=True),
    sa.Column("context_markdown", sa.Text(), nullable=True),
    sa.Column("active", sa.Boolean(), default=True, nullable=True),
    sa.Column("crontab", sa.String(length=1000), nullable=False),
    sa.Column("sql", sa.Text(), nullable=True),
    sa.Column("chart_id", sa.Integer(), nullable=True),
    sa.Column("dashboard_id", sa.Integer(), nullable=True),
    sa.Column("database_id", sa.Integer(), nullable=True),
    sa.Column("last_eval_dttm", sa.DateTime(), nullable=True),
    sa.Column("last_state", sa.String(length=50), nullable=True),
    sa.Column("last_value", sa.Float(), nullable=True),
    sa.Column("last_value_row_json", sa.Text(), nullable=True),
    sa.Column("validator_type", sa.String(length=100), nullable=True),
    sa.Column("validator_config_json", sa.Text(), default="{}", nullable=True),
    sa.Column("log_retention", sa.Integer(), nullable=True, default=90),
    sa.Column("grace_period", sa.Integer(), nullable=True, default=60 * 60 * 4),
    sa.Column("working_timeout", sa.Integer(), nullable=True, default=60 * 60 * 1),
    # Audit Mixin
    sa.Column("created_on", sa.DateTime(), nullable=True),
    sa.Column("changed_on", sa.DateTime(), nullable=True),
    sa.Column("created_by_fk", sa.Integer(), nullable=True),
    sa.Column("changed_by_fk", sa.Integer(), nullable=True),
    sa.ForeignKeyConstraint(["chart_id"], ["slices.id"]),
    sa.ForeignKeyConstraint(["dashboard_id"], ["dashboards.id"]),
    sa.ForeignKeyConstraint(["database_id"], ["dbs.id"]),
    sa.ForeignKeyConstraint(["changed_by_fk"], ["ab_user.id"]),
    sa.ForeignKeyConstraint(["created_by_fk"], ["ab_user.id"]),
    sa.PrimaryKeyConstraint("id"),
)


def upgrade():
    """Relax ``report_schedule.name`` from globally unique to unique per
    ``(name, type)`` so an alert and a report may share the same name.
    """
    bind = op.get_bind()

    # SQLite cannot drop/alter constraints in place, so it is handled
    # separately via batch mode in the ``else`` branch below.
    if not isinstance(bind.dialect, SQLiteDialect):
        op.drop_constraint("uq_report_schedule_name", "report_schedule", type_="unique")

        if isinstance(bind.dialect, MySQLDialect):
            # NOTE(review): on MySQL the old unique constraint presumably left
            # behind an index literally named "name" — confirm against a MySQL
            # instance upgraded from the previous revision.
            op.drop_index(
                op.f("name"), table_name="report_schedule",
            )

        if isinstance(bind.dialect, PGDialect):
            # Drop PostgreSQL's auto-generated unique constraint name as well.
            op.drop_constraint(
                "report_schedule_name_key", "report_schedule", type_="unique"
            )
        # Applies to both MySQL and PostgreSQL paths above.
        op.create_unique_constraint(
            "uq_report_schedule_name_type", "report_schedule", ["name", "type"]
        )

    else:
        # SQLite: recreate the table from the snapshot above with the new
        # composite unique constraint.
        with op.batch_alter_table(
            "report_schedule", copy_from=report_schedule
        ) as batch_op:
            batch_op.create_unique_constraint(
                "uq_report_schedule_name_type", ["name", "type"]
            )


def downgrade():
    # NOTE(review): intentionally a no-op — presumably because restoring the
    # old globally-unique ``name`` constraint could fail once an alert and a
    # report share a name; confirm this is acceptable for rollbacks.
    pass
4 changes: 3 additions & 1 deletion superset/models/reports.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,11 @@ class ReportSchedule(Model, AuditMixinNullable):
"""

__tablename__ = "report_schedule"
__table_args__ = (UniqueConstraint("name", "type"),)

id = Column(Integer, primary_key=True)
type = Column(String(50), nullable=False)
name = Column(String(150), nullable=False, unique=True)
name = Column(String(150), nullable=False)
description = Column(Text)
context_markdown = Column(Text)
active = Column(Boolean, default=True, index=True)
Expand Down
6 changes: 4 additions & 2 deletions superset/reports/commands/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,10 @@ def validate(self) -> None:
if not report_type:
exceptions.append(ReportScheduleRequiredTypeValidationError())

# Validate name uniqueness
if not ReportScheduleDAO.validate_update_uniqueness(name):
# Validate name type uniqueness
if report_type and not ReportScheduleDAO.validate_update_uniqueness(
name, report_type
):
exceptions.append(ReportScheduleNameUniquenessValidationError())

# validate relation by report type
Expand Down
2 changes: 1 addition & 1 deletion superset/reports/commands/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ class ReportScheduleWorkingTimeoutError(CommandException):

class ReportScheduleNameUniquenessValidationError(ValidationError):
"""
Marshmallow validation error for Report Schedule name already exists
Marshmallow validation error for Report Schedule name and type already exists
"""

def __init__(self) -> None:
Expand Down
11 changes: 6 additions & 5 deletions superset/reports/commands/update.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,15 +80,16 @@ def validate(self) -> None:
):
self._properties["last_state"] = ReportState.NOOP

# Validate name uniqueness
# validate relation by report type
if not report_type:
report_type = self._model.type

# Validate name type uniqueness
if not ReportScheduleDAO.validate_update_uniqueness(
name, report_schedule_id=self._model_id
name, report_type, report_schedule_id=self._model_id
):
exceptions.append(ReportScheduleNameUniquenessValidationError())

# validate relation by report type
if not report_type:
report_type = self._model.type
if report_type == ReportScheduleType.ALERT:
database_id = self._properties.get("database")
# If database_id was sent let's validate it exists
Expand Down
9 changes: 6 additions & 3 deletions superset/reports/dao.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,17 +111,20 @@ def bulk_delete(

@staticmethod
def validate_update_uniqueness(
name: str, report_schedule_id: Optional[int] = None
name: str, report_type: str, report_schedule_id: Optional[int] = None
) -> bool:
"""
Validate if this name is unique.
Validate if this name and type is unique.
:param name: The report schedule name
:param report_type: The report schedule type
:param report_schedule_id: The report schedule current id
(only for validating on updates)
:return: bool
"""
query = db.session.query(ReportSchedule).filter(ReportSchedule.name == name)
query = db.session.query(ReportSchedule).filter(
ReportSchedule.name == name, ReportSchedule.type == report_type
)
if report_schedule_id:
query = query.filter(ReportSchedule.id != report_schedule_id)
return not db.session.query(query.exists()).scalar()
Expand Down
14 changes: 12 additions & 2 deletions superset/reports/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Union
from typing import Any, Dict, Union

from croniter import croniter
from marshmallow import fields, Schema, validate
from marshmallow import fields, Schema, validate, validates_schema
from marshmallow.validate import Length, ValidationError

from superset.models.reports import (
Expand Down Expand Up @@ -170,6 +170,16 @@ class ReportSchedulePostSchema(Schema):

recipients = fields.List(fields.Nested(ReportRecipientSchema))

@validates_schema
def validate_report_references( # pylint: disable=unused-argument,no-self-use
self, data: Dict[str, Any], **kwargs: Any
) -> None:
if data["type"] == ReportScheduleType.REPORT:
if "database" in data:
raise ValidationError(
{"database": ["Database reference is not allowed on a report"]}
)


class ReportSchedulePutSchema(Schema):
type = fields.String(
Expand Down
40 changes: 40 additions & 0 deletions tests/reports/api_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,46 @@ def test_create_report_schedule_uniqueness(self):
data = json.loads(rv.data.decode("utf-8"))
assert data == {"message": {"name": ["Name must be unique"]}}

# Check that uniqueness is composed by name and type
report_schedule_data = {
"type": ReportScheduleType.REPORT,
"name": "name3",
"description": "description",
"crontab": "0 9 * * *",
"chart": chart.id,
}
uri = "api/v1/report/"
rv = self.client.post(uri, json=report_schedule_data)
assert rv.status_code == 201
data = json.loads(rv.data.decode("utf-8"))

# Rollback changes
created_model = db.session.query(ReportSchedule).get(data.get("id"))
db.session.delete(created_model)
db.session.commit()

@pytest.mark.usefixtures("create_report_schedules")
def test_create_report_schedule_schema(self):
"""
ReportSchedule Api: Test create report schedule schema check
"""
self.login(username="admin")
chart = db.session.query(Slice).first()
example_db = get_example_database()

# Check that a report does not have a database reference
report_schedule_data = {
"type": ReportScheduleType.REPORT,
"name": "name3",
"description": "description",
"crontab": "0 9 * * *",
"chart": chart.id,
"database": example_db.id,
}
uri = "api/v1/report/"
rv = self.client.post(uri, json=report_schedule_data)
assert rv.status_code == 400

@pytest.mark.usefixtures("create_report_schedules")
def test_create_report_schedule_chart_dash_validation(self):
"""
Expand Down

0 comments on commit 74f3faf

Please sign in to comment.