Skip to content

Commit

Permalink
Allow CSV upload to accept Parquet files
Browse files Browse the repository at this point in the history
  • Loading branch information
exemplary-citizen committed May 3, 2021
1 parent 4d2c932 commit 4cc378a
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 22 deletions.
3 changes: 2 additions & 1 deletion superset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -520,7 +520,8 @@ def _try_json_readsha( # pylint: disable=unused-argument
# Allowed format types for upload on Database view
# File extension groups accepted by the Database-view upload forms.
EXCEL_EXTENSIONS = {"xlsx", "xls"}
CSV_EXTENSIONS = {"csv", "tsv", "txt"}
# Columnar formats that go through the CSV upload flow but are parsed
# with pd.read_parquet instead of pd.read_csv.
OTHER_EXTENSIONS = {"parquet"}
# Union of every extension the upload endpoints will accept.
# NOTE: the stale pre-change assignment (without OTHER_EXTENSIONS) has been
# dropped so ALLOWED_EXTENSIONS is defined exactly once.
ALLOWED_EXTENSIONS = {*EXCEL_EXTENSIONS, *CSV_EXTENSIONS, *OTHER_EXTENSIONS}

# CSV Options: key/value pairs that will be passed as argument to DataFrame.to_csv
# method.
Expand Down
13 changes: 12 additions & 1 deletion superset/views/database/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,9 @@ def at_least_one_schema_is_allowed(database: Database) -> bool:
validators=[
FileRequired(),
FileAllowed(
config["ALLOWED_EXTENSIONS"].intersection(config["CSV_EXTENSIONS"]),
config["ALLOWED_EXTENSIONS"].intersection(
config["CSV_EXTENSIONS"].union(config["OTHER_EXTENSIONS"])
),
_(
"Only the following file extensions are allowed: "
"%(allowed_extensions)s",
Expand Down Expand Up @@ -163,6 +165,15 @@ def at_least_one_schema_is_allowed(database: Database) -> bool:
_("Mangle Duplicate Columns"),
description=_('Specify duplicate columns as "X.0, X.1".'),
)
usecols = JsonListField(
_("Use Columns"),
default=None,
description=_(
"Json list of the column names that should be read. "
"If not None, only these columns will be read from the file."
),
validators=[Optional()],
)
skipinitialspace = BooleanField(
_("Skip Initial Space"), description=_("Skip spaces after delimiter.")
)
Expand Down
48 changes: 28 additions & 20 deletions superset/views/database/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,32 @@ def form_get(self, form: CsvToDatabaseForm) -> None:
def form_post(self, form: CsvToDatabaseForm) -> Response:
database = form.con.data
csv_table = Table(table=form.name.data, schema=form.schema.data)
file_type = form.csv_file.data.filename.split(".")[-1]
if file_type == "parquet":
read = pd.read_parquet
kwargs = {
"columns": form.usecols.data,
}
else:
read = pd.read_csv
kwargs = {
"chunksize": 1000,
"encoding": "utf-8",
"header": form.header.data if form.header.data else 0,
"index_col": form.index_col.data,
"infer_datetime_format": form.infer_datetime_format.data,
"iterator": True,
"keep_default_na": not form.null_values.data,
"mangle_dupe_cols": form.mangle_dupe_cols.data,
"usecols": form.usecols.data,
"na_values": form.null_values.data if form.null_values.data else None,
"nrows": form.nrows.data,
"parse_dates": form.parse_dates.data,
"sep": form.sep.data,
"skip_blank_lines": form.skip_blank_lines.data,
"skipinitialspace": form.skipinitialspace.data,
"skiprows": form.skiprows.data,
}

if not schema_allows_csv_upload(database, csv_table.schema):
message = _(
Expand All @@ -151,26 +177,8 @@ def form_post(self, form: CsvToDatabaseForm) -> Response:
return redirect("/csvtodatabaseview/form")

try:
df = pd.concat(
pd.read_csv(
chunksize=1000,
encoding="utf-8",
filepath_or_buffer=form.csv_file.data,
header=form.header.data if form.header.data else 0,
index_col=form.index_col.data,
infer_datetime_format=form.infer_datetime_format.data,
iterator=True,
keep_default_na=not form.null_values.data,
mangle_dupe_cols=form.mangle_dupe_cols.data,
na_values=form.null_values.data if form.null_values.data else None,
nrows=form.nrows.data,
parse_dates=form.parse_dates.data,
sep=form.sep.data,
skip_blank_lines=form.skip_blank_lines.data,
skipinitialspace=form.skipinitialspace.data,
skiprows=form.skiprows.data,
)
)
chunks = read(form.csv_file.data, **kwargs)
df = pd.concat(chunks) if isinstance(chunks, list) else chunks

database = (
db.session.query(models.Database)
Expand Down
51 changes: 51 additions & 0 deletions tests/csv_upload_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
CSV_FILENAME1 = "testCSV1.csv"
CSV_FILENAME2 = "testCSV2.csv"
EXCEL_FILENAME = "testExcel.xlsx"
PARQUET_FILENAME = "testParquet.parquet"

EXCEL_UPLOAD_TABLE = "excel_upload"
CSV_UPLOAD_TABLE = "csv_upload"
Expand Down Expand Up @@ -90,6 +91,12 @@ def create_csv_files():
os.remove(CSV_FILENAME2)


def create_parquet_files():
    # NOTE(review): this generator is consumed as a fixture argument by
    # test_import_parquet, but unlike the sibling create_excel_files it is
    # missing the @pytest.fixture() decorator -- TODO confirm and add it,
    # otherwise pytest will not inject it as a fixture.
    # Write a small two-row DataFrame to a parquet file for the upload test,
    # yield control to the test, then delete the file on teardown.
    pd.DataFrame({"a": ["john", "paul"], "b": [1, 2]}).to_parquet(PARQUET_FILENAME)
    yield
    os.remove(PARQUET_FILENAME)


@pytest.fixture()
def create_excel_files():
pd.DataFrame({"a": ["john", "paul"], "b": [1, 2]}).to_excel(EXCEL_FILENAME)
Expand Down Expand Up @@ -328,3 +335,47 @@ def test_import_excel(setup_csv_upload, create_excel_files):
.fetchall()
)
assert data == [(0, "john", 1), (1, "paul", 2)]


@mock.patch("superset.db_engine_specs.hive.upload_to_s3", mock_upload_to_s3)
def test_import_parquet(setup_csv_upload, create_parquet_files):
    """Upload a parquet file through the CSV upload endpoint and verify the
    fail / append / replace ``if_exists`` modes behave as they do for CSV."""
    if utils.backend() == "hive":
        pytest.skip("Hive doesn't allow parquet upload.")

    success_msg = (
        f'CSV file "{PARQUET_FILENAME}" uploaded to table "{CSV_UPLOAD_TABLE}"'
    )

    # initial upload with fail mode creates the table
    resp = upload_csv(PARQUET_FILENAME, CSV_UPLOAD_TABLE)
    assert success_msg in resp

    # upload again with fail mode; should fail because the table now exists
    fail_msg = (
        f'Unable to upload CSV file "{PARQUET_FILENAME}" to table "{CSV_UPLOAD_TABLE}"'
    )
    resp = upload_csv(PARQUET_FILENAME, CSV_UPLOAD_TABLE)
    assert fail_msg in resp

    # hive was already skipped above, so append/replace can run unconditionally
    # (the previous `if utils.backend() != "hive"` guard was redundant)
    # upload again with append mode
    resp = upload_csv(
        PARQUET_FILENAME, CSV_UPLOAD_TABLE, extra={"if_exists": "append"}
    )
    assert success_msg in resp

    # upload again with replace mode
    resp = upload_csv(
        PARQUET_FILENAME, CSV_UPLOAD_TABLE, extra={"if_exists": "replace"}
    )
    assert success_msg in resp

    # replace mode should leave exactly the two rows written by the
    # create_parquet_files fixture (plus the generated index column)
    data = (
        get_upload_db()
        .get_sqla_engine()
        .execute(f"SELECT * from {CSV_UPLOAD_TABLE}")
        .fetchall()
    )
    assert data == [(0, "john", 1), (1, "paul", 2)]

0 comments on commit 4cc378a

Please sign in to comment.