From c581ea8661dd229f42d5959a6745a26ce8f03289 Mon Sep 17 00:00:00 2001 From: rumbin Date: Thu, 6 Apr 2017 18:42:43 +0200 Subject: [PATCH] Alternative PR for: Some bytes/str issues in py3 w/ zlib and json (#2558) * sql_lab.py: compress via utils * utils.py: added zlib_compress and zlib_compress_to_string * core.py: converted to use zlib_decompress_to_string; renamed uncompress to decompress in utils.py * utils_tests.py: added test for compress/decompress * fixed broken utils test; removed redundant code and empty lines from utils.py * utils.py: corrected docstrings, removed unnecessary 'else' * removed yet another superfluous else --- superset/sql_lab.py | 3 +-- superset/utils.py | 35 ++++++++++++++++++++++++++++++++++- superset/views/core.py | 5 ++--- tests/utils_tests.py | 9 ++++++++- 4 files changed, 45 insertions(+), 7 deletions(-) diff --git a/superset/sql_lab.py b/superset/sql_lab.py index cc42ed96fb389..f3dcbf120b65b 100644 --- a/superset/sql_lab.py +++ b/superset/sql_lab.py @@ -6,7 +6,6 @@ import pandas as pd import sqlalchemy import uuid -import zlib from sqlalchemy.pool import NullPool from sqlalchemy.orm import sessionmaker @@ -185,7 +184,7 @@ def handle_error(msg): if store_results: key = '{}'.format(uuid.uuid4()) logging.info("Storing results in results backend, key: {}".format(key)) - results_backend.set(key, zlib.compress(payload)) + results_backend.set(key, utils.zlib_compress(payload)) query.results_key = key session.merge(query) diff --git a/superset/utils.py b/superset/utils.py index 984f9c6834049..ec4c18d89681c 100644 --- a/superset/utils.py +++ b/superset/utils.py @@ -16,6 +16,8 @@ import sqlalchemy as sa import signal import uuid +import sys +import zlib from builtins import object from datetime import date, datetime, time @@ -41,7 +43,7 @@ logging.getLogger('MARKDOWN').setLevel(logging.INFO) - +PY3K = sys.version_info >= (3, 0) EPOCH = datetime(1970, 1, 1) DTTM_ALIAS = '__timestamp' @@ -572,3 +574,34 @@ def setup_cache(app, cache_config): """Setup the flask-cache on a flask app""" if cache_config and cache_config.get('CACHE_TYPE') != 'null': return Cache(app, config=cache_config) + + +def zlib_compress(data): + """ + Compress things in a py2/3 safe fashion + >>> json_str = '{"test": 1}' + >>> blob = zlib_compress(json_str) + """ + if PY3K: + if isinstance(data, str): + return zlib.compress(bytes(data, "utf-8")) + return zlib.compress(data) + return zlib.compress(data) + + +def zlib_decompress_to_string(blob): + """ + Decompress things to a string in a py2/3 safe fashion + >>> json_str = '{"test": 1}' + >>> blob = zlib_compress(json_str) + >>> got_str = zlib_decompress_to_string(blob) + >>> got_str == json_str + True + """ + if PY3K: + if isinstance(blob, bytes): + decompressed = zlib.decompress(blob) + else: + decompressed = zlib.decompress(bytes(blob, "utf-8")) + return decompressed.decode("utf-8") + return zlib.decompress(blob) diff --git a/superset/views/core.py b/superset/views/core.py index a3b708f2a65a8..22da33f388ea9 100755 --- a/superset/views/core.py +++ b/superset/views/core.py @@ -11,7 +11,6 @@ import re import time import traceback -import zlib import sqlalchemy as sqla @@ -1878,7 +1877,7 @@ def results(self, key): return json_error_response(get_datasource_access_error_msg( '{}'.format(rejected_tables))) - payload = zlib.decompress(blob) + payload = utils.zlib_decompress_to_string(blob) display_limit = app.config.get('DISPLAY_SQL_MAX_ROW', None) if display_limit: payload_json = json.loads(payload) @@ -2018,7 +2017,7 @@ def csv(self, client_id): if results_backend and query.results_key: blob = results_backend.get(query.results_key) if blob: - json_payload = zlib.decompress(blob) + json_payload = utils.zlib_decompress_to_string(blob) obj = json.loads(json_payload) columns = [c['name'] for c in obj['columns']] df = pd.DataFrame.from_records(obj['data'], columns=columns) diff --git a/tests/utils_tests.py b/tests/utils_tests.py index 231f03b84fe8a..e07d9594b3979 100644 --- a/tests/utils_tests.py +++ b/tests/utils_tests.py @@ -1,7 +1,7 @@ from datetime import datetime, date, timedelta, time from decimal import Decimal from superset.utils import ( - json_int_dttm_ser, json_iso_dttm_ser, base_json_conv, parse_human_timedelta + json_int_dttm_ser, json_iso_dttm_ser, base_json_conv, parse_human_timedelta, zlib_compress, zlib_decompress_to_string ) import unittest import uuid @@ -45,3 +45,10 @@ def test_base_json_conv(self): def test_parse_human_timedelta(self, mock_now): mock_now.return_value = datetime(2016, 12, 1) self.assertEquals(parse_human_timedelta('now'), timedelta(0)) + + def test_zlib_compression(self): + json_str = """{"test": 1}""" + blob = zlib_compress(json_str) + got_str = zlib_decompress_to_string(blob) + self.assertEquals(json_str, got_str) +