Skip to content
This repository has been archived by the owner on Nov 3, 2023. It is now read-only.

Commit

Permalink
Merge pull request #117 from airbnb/more_csv_changes
Browse files Browse the repository at this point in the history
Cherrypick more csv changes
  • Loading branch information
timifasubaa authored Sep 20, 2018
2 parents f48eddf + f939b7a commit bb2b51b
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 23 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ six==1.11.0
sqlalchemy==1.2.2
sqlalchemy-utils==0.32.21
sqlparse==0.2.4
tableschema==1.1.0
thrift==0.11.0
thrift-sasl==0.3.0
unicodecsv==0.14.1
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ def get_git_sha():
'sqlalchemy',
'sqlalchemy-utils',
'sqlparse',
'tableschema',
'thrift>=0.9.3',
'thrift-sasl>=0.2.1',
'unicodecsv',
Expand Down
58 changes: 35 additions & 23 deletions superset/db_engine_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
from sqlalchemy.engine.url import make_url
from sqlalchemy.sql import text
import sqlparse
import unicodecsv
from tableschema import Table
from werkzeug.utils import secure_filename

from superset import app, cache_util, conf, db, utils
Expand Down Expand Up @@ -90,7 +90,7 @@ def extra_table_metadata(cls, database, table_name, schema_name):
@staticmethod
def csv_to_df(**kwargs):
kwargs['filepath_or_buffer'] = \
app.config['UPLOAD_FOLDER'] + kwargs['filepath_or_buffer']
config['UPLOAD_FOLDER'] + kwargs['filepath_or_buffer']
kwargs['encoding'] = 'utf-8'
kwargs['iterator'] = True
chunks = pandas.read_csv(**kwargs)
Expand All @@ -112,7 +112,7 @@ def create_table_from_csv(form, table):
def _allowed_file(filename):
# Only allow specific file extensions as specified in the config
extension = os.path.splitext(filename)[1]
return extension and extension[1:] in app.config['ALLOWED_EXTENSIONS']
return extension and extension[1:] in config['ALLOWED_EXTENSIONS']

filename = secure_filename(form.csv_file.data.filename)
if not _allowed_file(filename):
Expand Down Expand Up @@ -905,9 +905,22 @@ def fetch_result_sets(cls, db, datasource_type, force=False):
@staticmethod
def create_table_from_csv(form, table):
"""Uploads a csv file and creates a superset datasource in Hive."""
def get_column_names(filepath):
with open(filepath, 'rb') as f:
return next(unicodecsv.reader(f, encoding='utf-8-sig'))
def convert_to_hive_type(col_type):
"""maps tableschema's types to hive types"""
tableschema_to_hive_types = {
'boolean': 'BOOLEAN',
'integer': 'INT',
'number': 'DOUBLE',
'string': 'STRING',
}
return tableschema_to_hive_types.get(col_type, 'STRING')

bucket_path = config['CSV_TO_HIVE_UPLOAD_S3_BUCKET']

if not bucket_path:
logging.info('No upload bucket specified')
raise Exception(
'No upload bucket specified. You can specify one in the config file.')

table_name = form.name.data
schema_name = form.schema.data
Expand All @@ -918,39 +931,38 @@ def get_column_names(filepath):
"You can't specify a namespace. "
'All tables will be uploaded to the `{}` namespace'.format(
config.get('HIVE_NAMESPACE')))
table_name = '{}.{}'.format(
full_table_name = '{}.{}'.format(
config.get('UPLOADED_CSV_HIVE_NAMESPACE'), table_name)
else:
if '.' in table_name and schema_name:
raise Exception(
"You can't specify a namespace both in the name of the table "
'and in the schema field. Please remove one')
if schema_name:
table_name = '{}.{}'.format(schema_name, table_name)

filename = form.csv_file.data.filename
bucket_path = config['CSV_TO_HIVE_UPLOAD_S3_BUCKET']
full_table_name = '{}.{}'.format(
schema_name, table_name) if schema_name else table_name

if not bucket_path:
logging.info('No upload bucket specified')
raise Exception(
'No upload bucket specified. You can specify one in the config file.')
filename = form.csv_file.data.filename

upload_prefix = app.config['CSV_TO_HIVE_UPLOAD_DIRECTORY']
dest_path = os.path.join(table_name, filename)
upload_prefix = config['CSV_TO_HIVE_UPLOAD_DIRECTORY']
upload_path = config['UPLOAD_FOLDER'] + \
secure_filename(filename)

upload_path = app.config['UPLOAD_FOLDER'] + \
secure_filename(form.csv_file.data.filename)
column_names = get_column_names(upload_path)
schema_definition = ', '.join(
[s + ' STRING ' for s in column_names])
hive_table_schema = Table(upload_path).infer()
column_name_and_type = []
for column_info in hive_table_schema['fields']:
column_name_and_type.append(
'`{}` {}'.format(
column_info['name'],
convert_to_hive_type(column_info['type'])))
schema_definition = ', '.join(column_name_and_type)

s3 = boto3.client('s3')
location = os.path.join('s3a://', bucket_path, upload_prefix, table_name)
s3.upload_file(
upload_path, bucket_path,
os.path.join(upload_prefix, table_name, filename))
sql = """CREATE TABLE {table_name} ( {schema_definition} )
sql = """CREATE TABLE {full_table_name} ( {schema_definition} )
ROW FORMAT DELIMITED FIELDS TERMINATED BY ',' STORED AS
TEXTFILE LOCATION '{location}'
tblproperties ('skip.header.line.count'='1')""".format(**locals())
Expand Down

0 comments on commit bb2b51b

Please sign in to comment.