Skip to content

Commit

Permalink
Refactor import csv (apache#4298)
Browse files: browse the repository at this point in the history
* move helpers to utils

* make form use queryselector

* refactor exception throwing and handling

* update db_connection access point

* nits

(cherry picked from commit 6d37d97)
  • Loading branch information
timifasubaa authored and Grace Guo committed Feb 5, 2018
1 parent 4a135c5 commit ff2bff9
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 53 deletions.
3 changes: 2 additions & 1 deletion superset/db_engine_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,13 +133,14 @@ def _allowed_file(filename):
'table': table,
'df': df,
'name': form.name.data,
'con': create_engine(form.con.data, echo=False),
'con': create_engine(form.con.data.sqlalchemy_uri, echo=False),
'schema': form.schema.data,
'if_exists': form.if_exists.data,
'index': form.index.data,
'index_label': form.index_label.data,
'chunksize': 10000,
}

BaseEngineSpec.df_to_db(**df_to_db_kwargs)

@classmethod
Expand Down
18 changes: 10 additions & 8 deletions superset/forms.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,20 @@
from flask_wtf.file import FileAllowed, FileField, FileRequired
from wtforms import (
BooleanField, IntegerField, SelectField, StringField)
from wtforms.ext.sqlalchemy.fields import QuerySelectField
from wtforms.validators import DataRequired, NumberRange, Optional

from superset import app
from superset import app, db
from superset.models import core as models

config = app.config


class CsvToDatabaseForm(DynamicForm):
# pylint: disable=E0211
def all_db_items():
return db.session.query(models.Database)

name = StringField(
_('Table Name'),
description=_('Name of table to be created from csv data.'),
Expand All @@ -28,12 +34,9 @@ class CsvToDatabaseForm(DynamicForm):
description=_('Select a CSV file to be uploaded to a database.'),
validators=[
FileRequired(), FileAllowed(['csv'], _('CSV Files Only!'))])

con = SelectField(
_('Database'),
description=_('database in which to add above table.'),
validators=[DataRequired()],
choices=[])
con = QuerySelectField(
query_factory=all_db_items,
get_pk=lambda a: a.id, get_label=lambda a: a.database_name)
sep = StringField(
_('Delimiter'),
description=_('Delimiter used by CSV file (for whitespace use \s+).'),
Expand All @@ -49,7 +52,6 @@ class CsvToDatabaseForm(DynamicForm):
('fail', _('Fail')), ('replace', _('Replace')),
('append', _('Append'))],
validators=[DataRequired()])

schema = StringField(
_('Schema'),
description=_('Specify a schema (if database flavour supports this).'),
Expand Down
62 changes: 24 additions & 38 deletions superset/views/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,11 @@
from flask_babel import gettext as __
from flask_babel import lazy_gettext as _
import pandas as pd
from six import text_type
import sqlalchemy as sqla
from sqlalchemy import create_engine
from sqlalchemy.engine.url import make_url
from sqlalchemy.exc import OperationalError
from sqlalchemy.exc import IntegrityError, OperationalError
from unidecode import unidecode
from werkzeug.routing import BaseConverter
from werkzeug.utils import secure_filename
Expand Down Expand Up @@ -163,8 +164,6 @@ def apply(self, query, func): # noqa
return query




class DatabaseView(SupersetModelView, DeleteMixin, YamlExportMixin): # noqa
datamodel = SQLAInterface(models.Database)

Expand Down Expand Up @@ -319,49 +318,36 @@ def form_get(self, form):
form.infer_datetime_format.data = True
form.decimal.data = '.'
form.if_exists.data = 'append'
all_datasources = (
db.session.query(
models.Database.sqlalchemy_uri,
models.Database.database_name)
.all()
)
form.con.choices += all_datasources

def form_post(self, form):
def _upload_file(csv_file):
if csv_file and csv_file.filename:
filename = secure_filename(csv_file.filename)
csv_file.save(os.path.join(config['UPLOAD_FOLDER'], filename))
return filename

csv_file = form.csv_file.data
_upload_file(csv_file)
table = SqlaTable(table_name=form.name.data)
database = (
db.session.query(models.Database)
.filter_by(sqlalchemy_uri=form.data.get('con'))
.one()
)
table.database = database
table.database_id = database.id
form.csv_file.data.filename = secure_filename(form.csv_file.data.filename)
csv_filename = form.csv_file.data.filename
try:
database.db_engine_spec.create_table_from_csv(form, table)
csv_file.save(os.path.join(config['UPLOAD_FOLDER'], csv_filename))
table = SqlaTable(table_name=form.name.data)
table.database = form.data.get('con')
table.database_id = table.database.id
table.database.db_engine_spec.create_table_from_csv(form, table)
except Exception as e:
os.remove(os.path.join(config['UPLOAD_FOLDER'], csv_file.filename))
flash(e, 'error')
return redirect('/tablemodelview/list/')
try:
os.remove(os.path.join(config['UPLOAD_FOLDER'], csv_filename))
except OSError:
pass
message = u'Table name {} already exists. Please pick another'.format(
form.name.data) if isinstance(e, IntegrityError) else text_type(e)
flash(
message,
'danger')
return redirect('/csvtodatabaseview/form')

os.remove(os.path.join(config['UPLOAD_FOLDER'], csv_file.filename))
os.remove(os.path.join(config['UPLOAD_FOLDER'], csv_filename))
# Go back to welcome page / splash screen
db_name = (
db.session.query(models.Database.database_name)
.filter_by(sqlalchemy_uri=form.data.get('con'))
.one()
)
message = _('CSV file "{0}" uploaded to table "{1}" in '
'database "{2}"'.format(form.csv_file.data.filename,
db_name = table.database.database_name
message = _(u'CSV file "{0}" uploaded to table "{1}" in '
'database "{2}"'.format(csv_filename,
form.name.data,
db_name[0]))
db_name))
flash(message, 'info')
return redirect('/tablemodelview/list/')

Expand Down
14 changes: 8 additions & 6 deletions tests/core_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,20 +804,22 @@ def test_import_csv(self):
test_file.write('john,1\n')
test_file.write('paul,2\n')
test_file.close()
main_db_uri = db.session.query(
models.Database.sqlalchemy_uri)\
.filter_by(database_name='main').all()
main_db_uri = (
db.session.query(models.Database)
.filter_by(database_name='main')
.all()
)

test_file = open(filename, 'rb')
form_data = {
'csv_file': test_file,
'sep': ',',
'name': table_name,
'con': main_db_uri[0][0],
'con': main_db_uri[0].id,
'if_exists': 'append',
'index_label': 'test_label',
'mangle_dupe_cols': False}

'mangle_dupe_cols': False,
}
url = '/databaseview/list/'
add_datasource_page = self.get_resp(url)
assert 'Upload a CSV' in add_datasource_page
Expand Down

0 comments on commit ff2bff9

Please sign in to comment.