diff --git a/requirements.txt b/requirements.txt index ea12de6a824ac..dbba7ec7e7766 100644 --- a/requirements.txt +++ b/requirements.txt @@ -34,6 +34,7 @@ six==1.11.0 sqlalchemy==1.2.2 sqlalchemy-utils==0.32.21 sqlparse==0.2.4 +tableschema==1.1.0 thrift==0.11.0 thrift-sasl==0.3.0 unicodecsv==0.14.1 diff --git a/setup.py b/setup.py index edb434ac3dab5..7713053740c2f 100644 --- a/setup.py +++ b/setup.py @@ -86,6 +86,7 @@ def get_git_sha(): 'sqlalchemy', 'sqlalchemy-utils', 'sqlparse', + 'tableschema', 'thrift>=0.9.3', 'thrift-sasl>=0.2.1', 'unicodecsv', diff --git a/superset/db_engine_specs.py b/superset/db_engine_specs.py index bf6ec99cbde43..e1e0025faed33 100644 --- a/superset/db_engine_specs.py +++ b/superset/db_engine_specs.py @@ -36,7 +36,7 @@ from sqlalchemy.engine.url import make_url from sqlalchemy.sql import text import sqlparse -import unicodecsv +from tableschema import Table from werkzeug.utils import secure_filename from superset import app, cache_util, conf, db, utils @@ -90,7 +90,7 @@ def extra_table_metadata(cls, database, table_name, schema_name): @staticmethod def csv_to_df(**kwargs): kwargs['filepath_or_buffer'] = \ - app.config['UPLOAD_FOLDER'] + kwargs['filepath_or_buffer'] + config['UPLOAD_FOLDER'] + kwargs['filepath_or_buffer'] kwargs['encoding'] = 'utf-8' kwargs['iterator'] = True chunks = pandas.read_csv(**kwargs) @@ -112,7 +112,7 @@ def create_table_from_csv(form, table): def _allowed_file(filename): # Only allow specific file extensions as specified in the config extension = os.path.splitext(filename)[1] - return extension and extension[1:] in app.config['ALLOWED_EXTENSIONS'] + return extension and extension[1:] in config['ALLOWED_EXTENSIONS'] filename = secure_filename(form.csv_file.data.filename) if not _allowed_file(filename): @@ -905,9 +905,22 @@ def fetch_result_sets(cls, db, datasource_type, force=False): @staticmethod def create_table_from_csv(form, table): """Uploads a csv file and creates a superset datasource in Hive.""" - def get_column_names(filepath): - with open(filepath, 'rb') as f: - return next(unicodecsv.reader(f, encoding='utf-8-sig')) + def convert_to_hive_type(col_type): + """maps tableschema's types to hive types""" + tableschema_to_hive_types = { + 'boolean': 'BOOLEAN', + 'integer': 'INT', + 'number': 'DOUBLE', + 'string': 'STRING', + } + return tableschema_to_hive_types.get(col_type, 'STRING') + + bucket_path = config['CSV_TO_HIVE_UPLOAD_S3_BUCKET'] + + if not bucket_path: + logging.info('No upload bucket specified') + raise Exception( + 'No upload bucket specified. You can specify one in the config file.') table_name = form.name.data schema_name = form.schema.data @@ -918,39 +931,38 @@ def get_column_names(filepath): "You can't specify a namespace. " 'All tables will be uploaded to the `{}` namespace'.format( config.get('HIVE_NAMESPACE'))) - table_name = '{}.{}'.format( + full_table_name = '{}.{}'.format( config.get('UPLOADED_CSV_HIVE_NAMESPACE'), table_name) else: if '.' in table_name and schema_name: raise Exception( "You can't specify a namespace both in the name of the table " 'and in the schema field. Please remove one') - if schema_name: - table_name = '{}.{}'.format(schema_name, table_name) - filename = form.csv_file.data.filename - bucket_path = config['CSV_TO_HIVE_UPLOAD_S3_BUCKET'] + full_table_name = '{}.{}'.format( + schema_name, table_name) if schema_name else table_name - if not bucket_path: - logging.info('No upload bucket specified') - raise Exception( - 'No upload bucket specified. You can specify one in the config file.') + filename = form.csv_file.data.filename - upload_prefix = app.config['CSV_TO_HIVE_UPLOAD_DIRECTORY'] - dest_path = os.path.join(table_name, filename) + upload_prefix = config['CSV_TO_HIVE_UPLOAD_DIRECTORY'] + upload_path = config['UPLOAD_FOLDER'] + \ + secure_filename(filename) - upload_path = app.config['UPLOAD_FOLDER'] + \ - secure_filename(form.csv_file.data.filename) - column_names = get_column_names(upload_path) - schema_definition = ', '.join( - [s + ' STRING ' for s in column_names]) + hive_table_schema = Table(upload_path).infer() + column_name_and_type = [] + for column_info in hive_table_schema['fields']: + column_name_and_type.append( + '`{}` {}'.format( + column_info['name'], + convert_to_hive_type(column_info['type']))) + schema_definition = ', '.join(column_name_and_type) s3 = boto3.client('s3') location = os.path.join('s3a://', bucket_path, upload_prefix, table_name) s3.upload_file( upload_path, bucket_path, os.path.join(upload_prefix, table_name, filename)) - sql = """CREATE TABLE {table_name} ( {schema_definition} ) + sql = """CREATE TABLE {full_table_name} ( {schema_definition} ) ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' STORED AS TEXTFILE LOCATION '{location}' tblproperties ('skip.header.line.count'='1')""".format(**locals())