Skip to content

Commit

Permalink
pyexcel/pyexcel-xls#11, ignore cases in file extension
Browse files Browse the repository at this point in the history
  • Loading branch information
chfw committed Nov 28, 2016
1 parent e990a9f commit a017add
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 5 deletions.
14 changes: 9 additions & 5 deletions pyexcel_io/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,11 @@ def get_io(file_type):
:param file_type: a supported file type
:returns: a appropriate io stream, None otherwise
"""
if file_type in text_stream_types:
__file_type = file_type.lower()

if __file_type in text_stream_types:
return StringIO()
elif file_type in binary_stream_types:
elif __file_type in binary_stream_types:
return BytesIO()
else:
return None
Expand All @@ -79,9 +81,11 @@ def get_io_type(file_type):
:param file_type: a supported file type
:returns: a appropriate io stream, None otherwise
"""
if file_type in text_stream_types:
__file_type = file_type.lower()

if __file_type in text_stream_types:
return "string"
elif file_type in binary_stream_types:
elif __file_type in binary_stream_types:
return "bytes"
else:
return None
Expand Down Expand Up @@ -121,7 +125,7 @@ def create_writer(file_type, library=None):

def _get_a_handler(factories, file_type, library):
if file_type in factories:
handler_dict = factories[file_type]
handler_dict = factories[file_type.lower()]
if library is not None:
handler_class = handler_dict.get(library, None)
if handler_class is None:
Expand Down
21 changes: 21 additions & 0 deletions tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,17 @@ def test_default_csv_format():
assert result['csv'] == [[1, 2, 3]]


def test_case_insentivity():
data = [['1', '2', '3']]
io = manager.get_io("CSV")
# test default format for saving is 'csv'
save_data(io, data)
io.seek(0)
# test default format for reading is 'csv'
result = get_data(io)
assert result['csv'] == [[1, 2, 3]]


def test_file_handle_as_input():
test_file = "file_handle.csv"
with open(test_file, 'w') as f:
Expand All @@ -145,6 +156,16 @@ def test_file_handle_as_input():
eq_(data['csv'], [[1, 2, 3]])


def test_file_type_case_insensitivity():
test_file = "file_handle.CSv"
with open(test_file, 'w') as f:
f.write("1,2,3")

with open(test_file, 'r') as f:
data = get_data(f, 'csv')
eq_(data['csv'], [[1, 2, 3]])


def test_file_handle_as_output():
test_file = "file_handle.csv"
with open(test_file, 'w') as f:
Expand Down

0 comments on commit a017add

Please sign in to comment.