diff --git a/superset/dashboards/permalink/commands/base.py b/superset/dashboards/permalink/commands/base.py index f4dc4f0726110..82e24264ca920 100644 --- a/superset/dashboards/permalink/commands/base.py +++ b/superset/dashboards/permalink/commands/base.py @@ -18,11 +18,12 @@ from superset.commands.base import BaseCommand from superset.key_value.shared_entries import get_permalink_salt -from superset.key_value.types import KeyValueResource, SharedKey +from superset.key_value.types import JsonKeyValueCodec, KeyValueResource, SharedKey class BaseDashboardPermalinkCommand(BaseCommand, ABC): resource = KeyValueResource.DASHBOARD_PERMALINK + codec = JsonKeyValueCodec() @property def salt(self) -> str: diff --git a/superset/dashboards/permalink/commands/create.py b/superset/dashboards/permalink/commands/create.py index 9569f83919f05..9f53208da69af 100644 --- a/superset/dashboards/permalink/commands/create.py +++ b/superset/dashboards/permalink/commands/create.py @@ -58,6 +58,7 @@ def run(self) -> str: resource=self.resource, key=get_deterministic_uuid(self.salt, (user_id, value)), value=value, + codec=self.codec, ).run() assert key.id # for type checks return encode_permalink_key(key=key.id, salt=self.salt) diff --git a/superset/dashboards/permalink/commands/get.py b/superset/dashboards/permalink/commands/get.py index f89f9444e7a4e..4206263a37fe5 100644 --- a/superset/dashboards/permalink/commands/get.py +++ b/superset/dashboards/permalink/commands/get.py @@ -39,7 +39,11 @@ def run(self) -> Optional[DashboardPermalinkValue]: self.validate() try: key = decode_permalink_id(self.key, salt=self.salt) - command = GetKeyValueCommand(resource=self.resource, key=key) + command = GetKeyValueCommand( + resource=self.resource, + key=key, + codec=self.codec, + ) value: Optional[DashboardPermalinkValue] = command.run() if value: DashboardDAO.get_by_id_or_slug(value["dashboardId"]) diff --git a/superset/explore/permalink/commands/base.py b/superset/explore/permalink/commands/base.py index bef9546e21686..a87183b7e9ed3 100644 --- a/superset/explore/permalink/commands/base.py +++ b/superset/explore/permalink/commands/base.py @@ -18,11 +18,12 @@ from superset.commands.base import BaseCommand from superset.key_value.shared_entries import get_permalink_salt -from superset.key_value.types import KeyValueResource, SharedKey +from superset.key_value.types import JsonKeyValueCodec, KeyValueResource, SharedKey class BaseExplorePermalinkCommand(BaseCommand, ABC): resource: KeyValueResource = KeyValueResource.EXPLORE_PERMALINK + codec = JsonKeyValueCodec() @property def salt(self) -> str: diff --git a/superset/explore/permalink/commands/create.py b/superset/explore/permalink/commands/create.py index 77ce04c4e47b0..fb02ec8ca8201 100644 --- a/superset/explore/permalink/commands/create.py +++ b/superset/explore/permalink/commands/create.py @@ -52,6 +52,7 @@ def run(self) -> str: command = CreateKeyValueCommand( resource=self.resource, value=value, + codec=self.codec, ) key = command.run() if key.id is None: diff --git a/superset/explore/permalink/commands/get.py b/superset/explore/permalink/commands/get.py index 3376cab080962..4823117ecef53 100644 --- a/superset/explore/permalink/commands/get.py +++ b/superset/explore/permalink/commands/get.py @@ -43,6 +43,7 @@ def run(self) -> Optional[ExplorePermalinkValue]: value: Optional[ExplorePermalinkValue] = GetKeyValueCommand( resource=self.resource, key=key, + codec=self.codec, ).run() if value: chart_id: Optional[int] = value.get("chartId") diff --git a/superset/extensions/metastore_cache.py b/superset/extensions/metastore_cache.py index 1e5cff7ee3ccf..f69276c908430 100644 --- a/superset/extensions/metastore_cache.py +++ b/superset/extensions/metastore_cache.py @@ -23,10 +23,11 @@ from flask_caching import BaseCache from superset.key_value.exceptions import KeyValueCreateFailedError -from superset.key_value.types import KeyValueResource +from superset.key_value.types import KeyValueResource, PickleKeyValueCodec from superset.key_value.utils import get_uuid_namespace RESOURCE = KeyValueResource.METASTORE_CACHE +CODEC = PickleKeyValueCodec() class SupersetMetastoreCache(BaseCache): @@ -68,6 +69,7 @@ def set(self, key: str, value: Any, timeout: Optional[int] = None) -> bool: resource=RESOURCE, key=self.get_key(key), value=value, + codec=CODEC, expires_on=self._get_expiry(timeout), ).run() return True @@ -80,6 +82,7 @@ def add(self, key: str, value: Any, timeout: Optional[int] = None) -> bool: CreateKeyValueCommand( resource=RESOURCE, value=value, + codec=CODEC, key=self.get_key(key), expires_on=self._get_expiry(timeout), ).run() @@ -92,7 +95,11 @@ def get(self, key: str) -> Any: # pylint: disable=import-outside-toplevel from superset.key_value.commands.get import GetKeyValueCommand - return GetKeyValueCommand(resource=RESOURCE, key=self.get_key(key)).run() + return GetKeyValueCommand( + resource=RESOURCE, + key=self.get_key(key), + codec=CODEC, + ).run() def has(self, key: str) -> bool: entry = self.get(key) diff --git a/superset/key_value/commands/create.py b/superset/key_value/commands/create.py index 93e99c223ba59..d66d99d6e9702 100644 --- a/superset/key_value/commands/create.py +++ b/superset/key_value/commands/create.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. import logging -import pickle from datetime import datetime from typing import Any, Optional, Union from uuid import UUID @@ -26,7 +25,7 @@ from superset.commands.base import BaseCommand from superset.key_value.exceptions import KeyValueCreateFailedError from superset.key_value.models import KeyValueEntry -from superset.key_value.types import Key, KeyValueResource +from superset.key_value.types import Key, KeyValueCodec, KeyValueResource from superset.utils.core import get_user_id logger = logging.getLogger(__name__) @@ -35,13 +34,15 @@ class CreateKeyValueCommand(BaseCommand): resource: KeyValueResource value: Any + codec: KeyValueCodec key: Optional[Union[int, UUID]] expires_on: Optional[datetime] - def __init__( + def __init__( # pylint: disable=too-many-arguments self, resource: KeyValueResource, value: Any, + codec: KeyValueCodec, key: Optional[Union[int, UUID]] = None, expires_on: Optional[datetime] = None, ): @@ -50,16 +51,24 @@ def __init__( :param resource: the resource (dashboard, chart etc) :param value: the value to persist in the key-value store + :param codec: codec used to encode the value :param key: id of entry (autogenerated if undefined) :param expires_on: entry expiration time - :return: the key associated with the persisted value + : """ self.resource = resource self.value = value + self.codec = codec self.key = key self.expires_on = expires_on def run(self) -> Key: + """ + Persist the value + + :return: the key associated with the persisted value + + """ try: return self.create() except SQLAlchemyError as ex: @@ -70,9 +79,13 @@ def validate(self) -> None: pass def create(self) -> Key: + try: + value = self.codec.encode(self.value) + except Exception as ex: # pylint: disable=broad-except + raise KeyValueCreateFailedError("Unable to encode value") from ex entry = KeyValueEntry( resource=self.resource.value, - value=pickle.dumps(self.value), + value=value, created_on=datetime.now(), created_by_fk=get_user_id(), expires_on=self.expires_on, diff --git a/superset/key_value/commands/get.py b/superset/key_value/commands/get.py index 44c02331cccb9..9d659f3bc7c06 100644 --- a/superset/key_value/commands/get.py +++ b/superset/key_value/commands/get.py @@ -16,7 +16,6 @@ # under the License. import logging -import pickle from datetime import datetime from typing import Any, Optional, Union from uuid import UUID @@ -27,7 +26,7 @@ from superset.commands.base import BaseCommand from superset.key_value.exceptions import KeyValueGetFailedError from superset.key_value.models import KeyValueEntry -from superset.key_value.types import KeyValueResource +from superset.key_value.types import KeyValueCodec, KeyValueResource from superset.key_value.utils import get_filter logger = logging.getLogger(__name__) @@ -36,17 +35,25 @@ class GetKeyValueCommand(BaseCommand): resource: KeyValueResource key: Union[int, UUID] + codec: KeyValueCodec - def __init__(self, resource: KeyValueResource, key: Union[int, UUID]): + def __init__( + self, + resource: KeyValueResource, + key: Union[int, UUID], + codec: KeyValueCodec, + ): """ Retrieve a key value entry :param resource: the resource (dashboard, chart etc) :param key: the key to retrieve + :param codec: codec used to decode the value :return: the value associated with the key if present """ self.resource = resource self.key = key + self.codec = codec def run(self) -> Any: try: @@ -66,5 +73,5 @@ def get(self) -> Optional[Any]: .first() ) if entry and (entry.expires_on is None or entry.expires_on > datetime.now()): - return pickle.loads(entry.value) + return self.codec.decode(entry.value) return None diff --git a/superset/key_value/commands/update.py b/superset/key_value/commands/update.py index b69ca5e70d76b..becd6d9ca8d01 100644 --- a/superset/key_value/commands/update.py +++ b/superset/key_value/commands/update.py @@ -16,7 +16,6 @@ # under the License. import logging -import pickle from datetime import datetime from typing import Any, Optional, Union from uuid import UUID @@ -27,7 +26,7 @@ from superset.commands.base import BaseCommand from superset.key_value.exceptions import KeyValueUpdateFailedError from superset.key_value.models import KeyValueEntry -from superset.key_value.types import Key, KeyValueResource +from superset.key_value.types import Key, KeyValueCodec, KeyValueResource from superset.key_value.utils import get_filter from superset.utils.core import get_user_id @@ -37,14 +36,16 @@ class UpdateKeyValueCommand(BaseCommand): resource: KeyValueResource value: Any + codec: KeyValueCodec key: Union[int, UUID] expires_on: Optional[datetime] - def __init__( + def __init__( # pylint: disable=too-many-arguments self, resource: KeyValueResource, key: Union[int, UUID], value: Any, + codec: KeyValueCodec, expires_on: Optional[datetime] = None, ): """ @@ -53,12 +54,14 @@ def __init__( :param resource: the resource (dashboard, chart etc) :param key: the key to update :param value: the value to persist in the key-value store + :param codec: codec used to encode the value :param expires_on: entry expiration time :return: the key associated with the updated value """ self.resource = resource self.key = key self.value = value + self.codec = codec self.expires_on = expires_on def run(self) -> Optional[Key]: @@ -80,7 +83,7 @@ def update(self) -> Optional[Key]: .first() ) if entry: - entry.value = pickle.dumps(self.value) + entry.value = self.codec.encode(self.value) entry.expires_on = self.expires_on entry.changed_on = datetime.now() entry.changed_by_fk = get_user_id() diff --git a/superset/key_value/commands/upsert.py b/superset/key_value/commands/upsert.py index 06b33c90fcfec..c5668f11610ab 100644 --- a/superset/key_value/commands/upsert.py +++ b/superset/key_value/commands/upsert.py @@ -16,7 +16,6 @@ # under the License. import logging -import pickle from datetime import datetime from typing import Any, Optional, Union from uuid import UUID @@ -31,7 +30,7 @@ KeyValueUpsertFailedError, ) from superset.key_value.models import KeyValueEntry -from superset.key_value.types import Key, KeyValueResource +from superset.key_value.types import Key, KeyValueCodec, KeyValueResource from superset.key_value.utils import get_filter from superset.utils.core import get_user_id @@ -42,13 +41,15 @@ class UpsertKeyValueCommand(BaseCommand): resource: KeyValueResource value: Any key: Union[int, UUID] + codec: KeyValueCodec expires_on: Optional[datetime] - def __init__( + def __init__( # pylint: disable=too-many-arguments self, resource: KeyValueResource, key: Union[int, UUID], value: Any, + codec: KeyValueCodec, expires_on: Optional[datetime] = None, ): """ @@ -57,13 +58,14 @@ def __init__( :param resource: the resource (dashboard, chart etc) :param key: the key to update :param value: the value to persist in the key-value store - :param key_type: the type of the key to update + :param codec: codec used to encode the value :param expires_on: entry expiration time :return: the key associated with the updated value """ self.resource = resource self.key = key self.value = value + self.codec = codec self.expires_on = expires_on def run(self) -> Key: @@ -85,7 +87,7 @@ def upsert(self) -> Key: .first() ) if entry: - entry.value = pickle.dumps(self.value) + entry.value = self.codec.encode(self.value) entry.expires_on = self.expires_on entry.changed_on = datetime.now() entry.changed_by_fk = get_user_id() @@ -96,6 +98,7 @@ def upsert(self) -> Key: return CreateKeyValueCommand( resource=self.resource, value=self.value, + codec=self.codec, key=self.key, expires_on=self.expires_on, ).run() diff --git a/superset/key_value/shared_entries.py b/superset/key_value/shared_entries.py index 5f4ded949808c..7895b759079ef 100644 --- a/superset/key_value/shared_entries.py +++ b/superset/key_value/shared_entries.py @@ -18,11 +18,12 @@ from typing import Any, Optional from uuid import uuid3 -from superset.key_value.types import KeyValueResource, SharedKey +from superset.key_value.types import JsonKeyValueCodec, KeyValueResource, SharedKey from superset.key_value.utils import get_uuid_namespace, random_key RESOURCE = KeyValueResource.APP NAMESPACE = get_uuid_namespace("") +CODEC = JsonKeyValueCodec() def get_shared_value(key: SharedKey) -> Optional[Any]: @@ -30,7 +31,7 @@ def get_shared_value(key: SharedKey) -> Optional[Any]: from superset.key_value.commands.get import GetKeyValueCommand uuid_key = uuid3(NAMESPACE, key) - return GetKeyValueCommand(RESOURCE, key=uuid_key).run() + return GetKeyValueCommand(RESOURCE, key=uuid_key, codec=CODEC).run() def set_shared_value(key: SharedKey, value: Any) -> None: @@ -38,7 +39,12 @@ def set_shared_value(key: SharedKey, value: Any) -> None: from superset.key_value.commands.create import CreateKeyValueCommand uuid_key = uuid3(NAMESPACE, key) - CreateKeyValueCommand(resource=RESOURCE, value=value, key=uuid_key).run() + CreateKeyValueCommand( + resource=RESOURCE, + value=value, + key=uuid_key, + codec=CODEC, + ).run() def get_permalink_salt(key: SharedKey) -> str: diff --git a/superset/key_value/types.py b/superset/key_value/types.py index c3064fbef4d42..07d06414f60ea 100644 --- a/superset/key_value/types.py +++ b/superset/key_value/types.py @@ -14,9 +14,14 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + +import json +import pickle +from abc import ABC, abstractmethod from dataclasses import dataclass from enum import Enum -from typing import Optional, TypedDict +from typing import Any, Optional, TypedDict from uuid import UUID @@ -42,3 +47,29 @@ class KeyValueResource(str, Enum): class SharedKey(str, Enum): DASHBOARD_PERMALINK_SALT = "dashboard_permalink_salt" EXPLORE_PERMALINK_SALT = "explore_permalink_salt" + + +class KeyValueCodec(ABC): + @abstractmethod + def encode(self, value: Any) -> bytes: + ... + + @abstractmethod + def decode(self, value: bytes) -> Any: + ... + + +class JsonKeyValueCodec(KeyValueCodec): + def encode(self, value: dict[Any, Any]) -> bytes: + return bytes(json.dumps(value), encoding="utf-8") + + def decode(self, value: bytes) -> dict[Any, Any]: + return json.loads(value) + + +class PickleKeyValueCodec(KeyValueCodec): + def encode(self, value: dict[Any, Any]) -> bytes: + return pickle.dumps(value) + + def decode(self, value: bytes) -> dict[Any, Any]: + return pickle.loads(value) diff --git a/superset/migrations/versions/2023-05-01_12-03_9c2a5681ddfd_convert_key_value_entries_to_json.py b/superset/migrations/versions/2023-05-01_12-03_9c2a5681ddfd_convert_key_value_entries_to_json.py new file mode 100644 index 0000000000000..57931ff821e7a --- /dev/null +++ b/superset/migrations/versions/2023-05-01_12-03_9c2a5681ddfd_convert_key_value_entries_to_json.py @@ -0,0 +1,79 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""convert key-value entries to json + +Revision ID: 9c2a5681ddfd +Revises: 7e67aecbf3f1 +Create Date: 2023-05-01 12:03:17.079862 + +""" + +# revision identifiers, used by Alembic. +revision = "9c2a5681ddfd" +down_revision = "7e67aecbf3f1" + +import io +import json +import pickle + +from alembic import op +from sqlalchemy import Column, Integer, LargeBinary, String +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import Session + +from superset import db +from superset.migrations.shared.utils import paginated_update + +Base = declarative_base() +VALUE_MAX_SIZE = 2**24 - 1 +RESOURCES_TO_MIGRATE = ("app", "dashboard_permalink", "explore_permalink") + + +class RestrictedUnpickler(pickle.Unpickler): + def find_class(self, module, name): + raise pickle.UnpicklingError(f"Unpickling of {module}.{name} is forbidden") + + +class KeyValueEntry(Base): + __tablename__ = "key_value" + id = Column(Integer, primary_key=True) + resource = Column(String(32), nullable=False) + value = Column(LargeBinary(length=VALUE_MAX_SIZE), nullable=False) + + +def upgrade(): + bind = op.get_bind() + session: Session = db.Session(bind=bind) + for entry in paginated_update( + session.query(KeyValueEntry).filter( + KeyValueEntry.resource.in_(RESOURCES_TO_MIGRATE) + ) + ): + value = RestrictedUnpickler(io.BytesIO(entry.value)).load() or {} + entry.value = bytes(json.dumps(value), encoding="utf-8") + + +def downgrade(): + bind = op.get_bind() + session: Session = db.Session(bind=bind) + for entry in paginated_update( + session.query(KeyValueEntry).filter( + KeyValueEntry.resource.in_(RESOURCES_TO_MIGRATE) + ), + ): + value = json.loads(entry.value) or {} + entry.value = pickle.dumps(value) diff --git a/superset/temporary_cache/api.py b/superset/temporary_cache/api.py index b6376c63c3b6f..85db65c62c8ac 100644 --- a/superset/temporary_cache/api.py +++ b/superset/temporary_cache/api.py @@ -24,6 +24,7 @@ from marshmallow import ValidationError from superset.constants import MODEL_API_RW_METHOD_PERMISSION_MAP, RouteMethod +from superset.key_value.types import JsonKeyValueCodec from superset.temporary_cache.commands.exceptions import ( TemporaryCacheAccessDeniedError, TemporaryCacheResourceNotFoundError, @@ -37,6 +38,8 @@ logger = logging.getLogger(__name__) +CODEC = JsonKeyValueCodec() + class TemporaryCacheRestApi(BaseSupersetApi, ABC): add_model_schema = TemporaryCachePostSchema() @@ -69,7 +72,12 @@ def post(self, pk: int) -> Response: try: item = self.add_model_schema.load(request.json) tab_id = request.args.get("tab_id") - args = CommandParameters(resource_id=pk, value=item["value"], tab_id=tab_id) + args = CommandParameters( + resource_id=pk, + value=item["value"], + tab_id=tab_id, + codec=CODEC, + ) key = self.get_create_command()(args).run() return self.response(201, key=key) except ValidationError as ex: @@ -89,6 +97,7 @@ def put(self, pk: int, key: str) -> Response: key=key, value=item["value"], tab_id=tab_id, + codec=CODEC, ) key = self.get_update_command()(args).run() return self.response(200, key=key) @@ -101,7 +110,7 @@ def put(self, pk: int, key: str) -> Response: def get(self, pk: int, key: str) -> Response: try: - args = CommandParameters(resource_id=pk, key=key) + args = CommandParameters(resource_id=pk, key=key, codec=CODEC) value = self.get_get_command()(args).run() if not value: return self.response_404() diff --git a/superset/temporary_cache/commands/parameters.py b/superset/temporary_cache/commands/parameters.py index 74b9c1c6321e5..e4e5b9b06a283 100644 --- a/superset/temporary_cache/commands/parameters.py +++ b/superset/temporary_cache/commands/parameters.py @@ -17,10 +17,13 @@ from dataclasses import dataclass from typing import Optional +from superset.key_value.types import KeyValueCodec + @dataclass class CommandParameters: resource_id: int + codec: Optional[KeyValueCodec] = None tab_id: Optional[int] = None key: Optional[str] = None value: Optional[str] = None diff --git a/tests/integration_tests/explore/permalink/api_tests.py b/tests/integration_tests/explore/permalink/api_tests.py index 22a36f41e1be5..4c6a3c12ddfdb 100644 --- a/tests/integration_tests/explore/permalink/api_tests.py +++ b/tests/integration_tests/explore/permalink/api_tests.py @@ -15,7 +15,6 @@ # specific language governing permissions and limitations # under the License. import json -import pickle from typing import Any, Dict, Iterator from uuid import uuid3 @@ -24,7 +23,7 @@ from superset import db from superset.key_value.models import KeyValueEntry -from superset.key_value.types import KeyValueResource +from superset.key_value.types import JsonKeyValueCodec, KeyValueResource from superset.key_value.utils import decode_permalink_id, encode_permalink_key from superset.models.slice import Slice from superset.utils.core import DatasourceType @@ -95,7 +94,7 @@ def test_get_missing_chart( chart_id = 1234 entry = KeyValueEntry( resource=KeyValueResource.EXPLORE_PERMALINK, - value=pickle.dumps( + value=JsonKeyValueCodec().encode( { "chartId": chart_id, "datasourceId": chart.datasource.id, diff --git a/tests/integration_tests/key_value/commands/create_test.py b/tests/integration_tests/key_value/commands/create_test.py index 0e789026baff4..a2ee3d13aed22 100644 --- a/tests/integration_tests/key_value/commands/create_test.py +++ b/tests/integration_tests/key_value/commands/create_test.py @@ -16,20 +16,23 @@ # under the License. from __future__ import annotations +import json import pickle -from uuid import UUID +import pytest from flask.ctx import AppContext from flask_appbuilder.security.sqla.models import User from superset.extensions import db +from superset.key_value.exceptions import KeyValueCreateFailedError from superset.utils.core import override_user from tests.integration_tests.key_value.commands.fixtures import ( admin, - ID_KEY, + JSON_CODEC, + JSON_VALUE, + PICKLE_CODEC, + PICKLE_VALUE, RESOURCE, - UUID_KEY, - VALUE, ) @@ -38,11 +41,15 @@ def test_create_id_entry(app_context: AppContext, admin: User) -> None: from superset.key_value.models import KeyValueEntry with override_user(admin): - key = CreateKeyValueCommand(resource=RESOURCE, value=VALUE).run() + key = CreateKeyValueCommand( + resource=RESOURCE, + value=JSON_VALUE, + codec=JSON_CODEC, + ).run() entry = ( db.session.query(KeyValueEntry).filter_by(id=key.id).autoflush(False).one() ) - assert pickle.loads(entry.value) == VALUE + assert json.loads(entry.value) == JSON_VALUE assert entry.created_by_fk == admin.id db.session.delete(entry) db.session.commit() @@ -53,11 +60,43 @@ def test_create_uuid_entry(app_context: AppContext, admin: User) -> None: from superset.key_value.models import KeyValueEntry with override_user(admin): - key = CreateKeyValueCommand(resource=RESOURCE, value=VALUE).run() + key = CreateKeyValueCommand( + resource=RESOURCE, value=JSON_VALUE, codec=JSON_CODEC + ).run() entry = ( db.session.query(KeyValueEntry).filter_by(uuid=key.uuid).autoflush(False).one() ) - assert pickle.loads(entry.value) == VALUE + assert json.loads(entry.value) == JSON_VALUE assert entry.created_by_fk == admin.id db.session.delete(entry) db.session.commit() + + +def test_create_fail_json_entry(app_context: AppContext, admin: User) -> None: + from superset.key_value.commands.create import CreateKeyValueCommand + + with pytest.raises(KeyValueCreateFailedError): + CreateKeyValueCommand( + resource=RESOURCE, + value=PICKLE_VALUE, + codec=JSON_CODEC, + ).run() + + +def test_create_pickle_entry(app_context: AppContext, admin: User) -> None: + from superset.key_value.commands.create import CreateKeyValueCommand + from superset.key_value.models import KeyValueEntry + + with override_user(admin): + key = CreateKeyValueCommand( + resource=RESOURCE, + value=PICKLE_VALUE, + codec=PICKLE_CODEC, + ).run() + entry = ( + db.session.query(KeyValueEntry).filter_by(id=key.id).autoflush(False).one() + ) + assert type(pickle.loads(entry.value)) == type(PICKLE_VALUE) + assert entry.created_by_fk == admin.id + db.session.delete(entry) + db.session.commit() diff --git a/tests/integration_tests/key_value/commands/delete_test.py b/tests/integration_tests/key_value/commands/delete_test.py index 62f9883370cf1..3c4892faa6467 100644 --- a/tests/integration_tests/key_value/commands/delete_test.py +++ b/tests/integration_tests/key_value/commands/delete_test.py @@ -16,7 +16,7 @@ # under the License. from __future__ import annotations -import pickle +import json from typing import TYPE_CHECKING from uuid import UUID @@ -25,7 +25,11 @@ from flask_appbuilder.security.sqla.models import User from superset.extensions import db -from tests.integration_tests.key_value.commands.fixtures import admin, RESOURCE, VALUE +from tests.integration_tests.key_value.commands.fixtures import ( + admin, + JSON_VALUE, + RESOURCE, +) if TYPE_CHECKING: from superset.key_value.models import KeyValueEntry @@ -42,7 +46,7 @@ def key_value_entry() -> KeyValueEntry: id=ID_KEY, uuid=UUID_KEY, resource=RESOURCE, - value=pickle.dumps(VALUE), + value=bytes(json.dumps(JSON_VALUE), encoding="utf-8"), ) db.session.add(entry) db.session.commit() @@ -55,7 +59,6 @@ def test_delete_id_entry( key_value_entry: KeyValueEntry, ) -> None: from superset.key_value.commands.delete import DeleteKeyValueCommand - from superset.key_value.models import KeyValueEntry assert DeleteKeyValueCommand(resource=RESOURCE, key=ID_KEY).run() is True @@ -66,7 +69,6 @@ def test_delete_uuid_entry( key_value_entry: KeyValueEntry, ) -> None: from superset.key_value.commands.delete import DeleteKeyValueCommand - from superset.key_value.models import KeyValueEntry assert DeleteKeyValueCommand(resource=RESOURCE, key=UUID_KEY).run() is True @@ -77,6 +79,5 @@ def test_delete_entry_missing( key_value_entry: KeyValueEntry, ) -> None: from superset.key_value.commands.delete import DeleteKeyValueCommand - from superset.key_value.models import KeyValueEntry assert DeleteKeyValueCommand(resource=RESOURCE, key=456).run() is False diff --git a/tests/integration_tests/key_value/commands/fixtures.py b/tests/integration_tests/key_value/commands/fixtures.py index 2fd4fde4e1dc3..66aea8a4edd27 100644 --- a/tests/integration_tests/key_value/commands/fixtures.py +++ b/tests/integration_tests/key_value/commands/fixtures.py @@ -17,7 +17,7 @@ from __future__ import annotations -import pickle +import json from typing import Generator, TYPE_CHECKING from uuid import UUID @@ -26,7 +26,11 @@ from sqlalchemy.orm import Session from superset.extensions import db -from superset.key_value.types import KeyValueResource +from superset.key_value.types import ( + JsonKeyValueCodec, + KeyValueResource, + PickleKeyValueCodec, +) from tests.integration_tests.test_app import app if TYPE_CHECKING: @@ -35,7 +39,10 @@ ID_KEY = 123 UUID_KEY = UUID("3e7a2ab8-bcaf-49b0-a5df-dfb432f291cc") RESOURCE = KeyValueResource.APP -VALUE = {"foo": "bar"} +JSON_VALUE = {"foo": "bar"} +PICKLE_VALUE = object() +JSON_CODEC = JsonKeyValueCodec() +PICKLE_CODEC = PickleKeyValueCodec() @pytest.fixture @@ -46,7 +53,7 @@ def key_value_entry() -> Generator[KeyValueEntry, None, None]: id=ID_KEY, uuid=UUID_KEY, resource=RESOURCE, - value=pickle.dumps(VALUE), + value=bytes(json.dumps(JSON_VALUE), encoding="utf-8"), ) db.session.add(entry) db.session.commit() diff --git a/tests/integration_tests/key_value/commands/get_test.py b/tests/integration_tests/key_value/commands/get_test.py index b1800a4c3b9a3..28a6dd73d5f04 100644 --- a/tests/integration_tests/key_value/commands/get_test.py +++ b/tests/integration_tests/key_value/commands/get_test.py @@ -16,7 +16,7 @@ # under the License. from __future__ import annotations -import pickle +import json import uuid from datetime import datetime, timedelta from typing import TYPE_CHECKING @@ -26,10 +26,11 @@ from superset.extensions import db from tests.integration_tests.key_value.commands.fixtures import ( ID_KEY, + JSON_CODEC, + JSON_VALUE, key_value_entry, RESOURCE, UUID_KEY, - VALUE, ) if TYPE_CHECKING: @@ -39,8 +40,8 @@ def test_get_id_entry(app_context: AppContext, key_value_entry: KeyValueEntry) -> None: from superset.key_value.commands.get import GetKeyValueCommand - value = GetKeyValueCommand(resource=RESOURCE, key=ID_KEY).run() - assert value == VALUE + value = GetKeyValueCommand(resource=RESOURCE, key=ID_KEY, codec=JSON_CODEC).run() + assert value == JSON_VALUE def test_get_uuid_entry( @@ -48,8 +49,8 @@ def test_get_uuid_entry( ) -> None: from superset.key_value.commands.get import GetKeyValueCommand - value = GetKeyValueCommand(resource=RESOURCE, key=UUID_KEY).run() - assert value == VALUE + value = GetKeyValueCommand(resource=RESOURCE, key=UUID_KEY, codec=JSON_CODEC).run() + assert value == JSON_VALUE def test_get_id_entry_missing( @@ -58,7 +59,7 @@ def test_get_id_entry_missing( ) -> None: from superset.key_value.commands.get import GetKeyValueCommand - value = GetKeyValueCommand(resource=RESOURCE, key=456).run() + value = GetKeyValueCommand(resource=RESOURCE, key=456, codec=JSON_CODEC).run() assert value is None @@ -70,12 +71,12 @@ def test_get_expired_entry(app_context: AppContext) -> None: id=678, uuid=uuid.uuid4(), resource=RESOURCE, - value=pickle.dumps(VALUE), + value=bytes(json.dumps(JSON_VALUE), encoding="utf-8"), expires_on=datetime.now() - timedelta(days=1), ) db.session.add(entry) db.session.commit() - value = GetKeyValueCommand(resource=RESOURCE, key=ID_KEY).run() + value = GetKeyValueCommand(resource=RESOURCE, key=ID_KEY, codec=JSON_CODEC).run() assert value is None db.session.delete(entry) db.session.commit() @@ -90,12 +91,12 @@ def test_get_future_expiring_entry(app_context: AppContext) -> None: id=id_, uuid=uuid.uuid4(), resource=RESOURCE, - value=pickle.dumps(VALUE), + value=bytes(json.dumps(JSON_VALUE), encoding="utf-8"), expires_on=datetime.now() + timedelta(days=1), ) db.session.add(entry) db.session.commit() - value = GetKeyValueCommand(resource=RESOURCE, key=id_).run() - assert value == VALUE + value = GetKeyValueCommand(resource=RESOURCE, key=id_, codec=JSON_CODEC).run() + assert value == JSON_VALUE db.session.delete(entry) db.session.commit() diff --git a/tests/integration_tests/key_value/commands/update_test.py b/tests/integration_tests/key_value/commands/update_test.py index 8eb03b4eda9eb..2c0fc3e31de51 100644 --- a/tests/integration_tests/key_value/commands/update_test.py +++ b/tests/integration_tests/key_value/commands/update_test.py @@ -16,9 +16,8 @@ # under the License. from __future__ import annotations -import pickle +import json from typing import TYPE_CHECKING -from uuid import UUID from flask.ctx import AppContext from flask_appbuilder.security.sqla.models import User @@ -28,6 +27,7 @@ from tests.integration_tests.key_value.commands.fixtures import ( admin, ID_KEY, + JSON_CODEC, key_value_entry, RESOURCE, UUID_KEY, @@ -53,11 +53,12 @@ def test_update_id_entry( resource=RESOURCE, key=ID_KEY, value=NEW_VALUE, + codec=JSON_CODEC, ).run() assert key is not None assert key.id == ID_KEY entry = db.session.query(KeyValueEntry).filter_by(id=ID_KEY).autoflush(False).one() - assert pickle.loads(entry.value) == NEW_VALUE + assert json.loads(entry.value) == NEW_VALUE assert entry.changed_by_fk == admin.id @@ -74,13 +75,14 @@ def test_update_uuid_entry( resource=RESOURCE, key=UUID_KEY, value=NEW_VALUE, + codec=JSON_CODEC, ).run() assert key is not None assert key.uuid == UUID_KEY entry = ( db.session.query(KeyValueEntry).filter_by(uuid=UUID_KEY).autoflush(False).one() ) - assert pickle.loads(entry.value) == NEW_VALUE + assert json.loads(entry.value) == NEW_VALUE assert entry.changed_by_fk == admin.id @@ -92,5 +94,6 @@ def test_update_missing_entry(app_context: AppContext, admin: User) -> None: resource=RESOURCE, key=456, value=NEW_VALUE, + codec=JSON_CODEC, ).run() assert key is None diff --git a/tests/integration_tests/key_value/commands/upsert_test.py b/tests/integration_tests/key_value/commands/upsert_test.py index e5cd27e3a6cc8..c26b66d02e7bf 100644 --- a/tests/integration_tests/key_value/commands/upsert_test.py +++ b/tests/integration_tests/key_value/commands/upsert_test.py @@ -16,9 +16,8 @@ # under the License. from __future__ import annotations -import pickle +import json from typing import TYPE_CHECKING -from uuid import UUID from flask.ctx import AppContext from flask_appbuilder.security.sqla.models import User @@ -28,6 +27,7 @@ from tests.integration_tests.key_value.commands.fixtures import ( admin, ID_KEY, + JSON_CODEC, key_value_entry, RESOURCE, UUID_KEY, @@ -53,13 +53,14 @@ def test_upsert_id_entry( resource=RESOURCE, key=ID_KEY, value=NEW_VALUE, + codec=JSON_CODEC, ).run() assert key is not None assert key.id == ID_KEY entry = ( db.session.query(KeyValueEntry).filter_by(id=int(ID_KEY)).autoflush(False).one() ) - assert pickle.loads(entry.value) == NEW_VALUE + assert json.loads(entry.value) == NEW_VALUE assert entry.changed_by_fk == admin.id @@ -76,13 +77,14 @@ def test_upsert_uuid_entry( resource=RESOURCE, key=UUID_KEY, value=NEW_VALUE, + codec=JSON_CODEC, ).run() assert key is not None assert key.uuid == UUID_KEY entry = ( db.session.query(KeyValueEntry).filter_by(uuid=UUID_KEY).autoflush(False).one() ) - assert pickle.loads(entry.value) == NEW_VALUE + assert json.loads(entry.value) == NEW_VALUE assert entry.changed_by_fk == admin.id @@ -95,6 +97,7 @@ def test_upsert_missing_entry(app_context: AppContext, admin: User) -> None: resource=RESOURCE, key=456, value=NEW_VALUE, + codec=JSON_CODEC, ).run() assert key is not None assert key.id == 456