From a017add8668e8513d9c0b30c958f2f7a35c6ae13 Mon Sep 17 00:00:00 2001 From: chfw Date: Mon, 28 Nov 2016 18:01:53 +0000 Subject: [PATCH] https://github.com/pyexcel/pyexcel-xls/issues/11, ignore cases in file extension --- pyexcel_io/manager.py | 14 +++++++++----- tests/test_io.py | 21 +++++++++++++++++++++ 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/pyexcel_io/manager.py b/pyexcel_io/manager.py index 8e13ab2..69890e5 100644 --- a/pyexcel_io/manager.py +++ b/pyexcel_io/manager.py @@ -65,9 +65,11 @@ def get_io(file_type): :param file_type: a supported file type :returns: a appropriate io stream, None otherwise """ - if file_type in text_stream_types: + __file_type = file_type.lower() + + if __file_type in text_stream_types: return StringIO() - elif file_type in binary_stream_types: + elif __file_type in binary_stream_types: return BytesIO() else: return None @@ -79,9 +81,11 @@ def get_io_type(file_type): :param file_type: a supported file type :returns: a appropriate io stream, None otherwise """ - if file_type in text_stream_types: + __file_type = file_type.lower() + + if __file_type in text_stream_types: return "string" - elif file_type in binary_stream_types: + elif __file_type in binary_stream_types: return "bytes" else: return None @@ -121,7 +125,7 @@ def create_writer(file_type, library=None): def _get_a_handler(factories, file_type, library): if file_type in factories: - handler_dict = factories[file_type] + handler_dict = factories[file_type.lower()] if library is not None: handler_class = handler_dict.get(library, None) if handler_class is None: diff --git a/tests/test_io.py b/tests/test_io.py index 6e9a1fd..85def44 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -135,6 +135,17 @@ def test_default_csv_format(): assert result['csv'] == [[1, 2, 3]] +def test_case_insentivity(): + data = [['1', '2', '3']] + io = manager.get_io("CSV") + # test default format for saving is 'csv' + save_data(io, data) + io.seek(0) + # test default format for reading is 'csv' + result = get_data(io) + assert result['csv'] == [[1, 2, 3]] + + def test_file_handle_as_input(): test_file = "file_handle.csv" with open(test_file, 'w') as f: @@ -145,6 +156,16 @@ def test_file_handle_as_input(): eq_(data['csv'], [[1, 2, 3]]) +def test_file_type_case_insensitivity(): + test_file = "file_handle.CSv" + with open(test_file, 'w') as f: + f.write("1,2,3") + + with open(test_file, 'r') as f: + data = get_data(f, 'csv') + eq_(data['csv'], [[1, 2, 3]]) + + def test_file_handle_as_output(): test_file = "file_handle.csv" with open(test_file, 'w') as f: