Properly close opened files in XportReader and SAS7BDATReader
agraboso committed Aug 10, 2016
1 parent 240383c commit 7aa5184
Showing 6 changed files with 36 additions and 3 deletions.
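
For context, the user-visible contract after this change, sketched below with a hypothetical file name: a plain read_sas call now closes the file before returning, while iterator/chunksize usage hands the open reader, and the responsibility to close it, to the caller. This mirrors the calling patterns exercised by the tests in this commit.

import pandas as pd

# Full read: the reader is closed internally before the DataFrame is returned.
df = pd.read_sas("demo.xpt", format="xport")

# Incremental read: the caller owns the handle and should close it when done.
reader = pd.read_sas("demo.xpt", format="xport", chunksize=10)
try:
    chunk = reader.get_chunk()
finally:
    reader.close()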
4 changes: 2 additions & 2 deletions doc/source/whatsnew/v0.19.0.txt
@@ -902,8 +902,8 @@ Bug Fixes
- Bug in ``pd.read_csv()`` that prevents ``usecols`` from being an empty set (:issue:`13402`)
- Bug in ``pd.read_csv()`` with ``engine='c'`` in which null ``quotechar`` was not accepted even though ``quoting`` was specified as ``None`` (:issue:`13411`)
- Bug in ``pd.read_csv()`` with ``engine='c'`` in which fields were not properly cast to float when quoting was specified as non-numeric (:issue:`13411`)
-- Bug in ``pd.read_csv``, ``pd.read_table`` and ``pd.read_stata`` where files were opened by parsers but not closed if both ``chunksize`` and ``iterator`` were ``None``. (:issue:`13940`)
-- Bug in ``StataReader`` and ``StataWriter`` where a file was not properly closed when an error was raised. (:issue:`13940`)
+- Bug in ``pd.read_csv``, ``pd.read_table``, ``pd.read_fwf``, ``pd.read_stata`` and ``pd.read_sas`` where files were opened by parsers but not closed if both ``chunksize`` and ``iterator`` were ``None``. (:issue:`13940`)
+- Bug in ``StataReader``, ``StataWriter``, ``XportReader`` and ``SAS7BDATReader`` where a file was not properly closed when an error was raised. (:issue:`13940`)

- Bug in ``pd.pivot_table()`` where ``margins_name`` is ignored when ``aggfunc`` is a list (:issue:`13354`)
- Bug in ``pd.Series.str.zfill``, ``center``, ``ljust``, ``rjust``, and ``pad`` when passing non-integers, did not raise ``TypeError`` (:issue:`13598`)
17 changes: 17 additions & 0 deletions pandas/io/sas/sas7bdat.py
@@ -92,16 +92,24 @@ def __init__(self, path_or_buf, index=None, convert_dates=True,
self._path_or_buf, _, _ = get_filepath_or_buffer(path_or_buf)
if isinstance(self._path_or_buf, compat.string_types):
self._path_or_buf = open(self._path_or_buf, 'rb')
self.handle = self._path_or_buf

self._get_properties()
self._parse_metadata()

def close(self):
try:
self.handle.close()
except AttributeError:
pass

def _get_properties(self):

# Check magic number
self._path_or_buf.seek(0)
self._cached_page = self._path_or_buf.read(288)
if self._cached_page[0:len(const.magic)] != const.magic:
self.close()
raise ValueError("magic number mismatch (not a SAS file?)")

# Get alignment information
@@ -175,6 +183,7 @@ def _get_properties(self):
buf = self._path_or_buf.read(self.header_length - 288)
self._cached_page += buf
if len(self._cached_page) != self.header_length:
self.close()
raise ValueError("The SAS7BDAT file appears to be truncated.")

self._page_length = self._read_int(const.page_size_offset + align1,
@@ -219,6 +228,7 @@ def _get_properties(self):
# Read a single float of the given width (4 or 8).
def _read_float(self, offset, width):
if width not in (4, 8):
self.close()
raise ValueError("invalid float width")
buf = self._read_bytes(offset, width)
fd = "f" if width == 4 else "d"
@@ -227,6 +237,7 @@ def _read_float(self, offset, width):
# Read a single signed integer of the given width (1, 2, 4 or 8).
def _read_int(self, offset, width):
if width not in (1, 2, 4, 8):
self.close()
raise ValueError("invalid int width")
buf = self._read_bytes(offset, width)
it = {1: "b", 2: "h", 4: "l", 8: "q"}[width]
@@ -238,11 +249,13 @@ def _read_bytes(self, offset, length):
self._path_or_buf.seek(offset)
buf = self._path_or_buf.read(length)
if len(buf) < length:
self.close()
msg = "Unable to read {:d} bytes from file position {:d}."
raise ValueError(msg.format(length, offset))
return buf
else:
if offset + length > len(self._cached_page):
self.close()
raise ValueError("The cached page is too small.")
return self._cached_page[offset:offset + length]

@@ -253,6 +266,7 @@ def _parse_metadata(self):
if len(self._cached_page) <= 0:
break
if len(self._cached_page) != self._page_length:
self.close()
raise ValueError(
"Failed to read a meta data page from the SAS file.")
done = self._process_page_meta()
@@ -302,6 +316,7 @@ def _get_subheader_index(self, signature, compression, ptype):
if (self.compression != "") and f1 and f2:
index = const.index.dataSubheaderIndex
else:
self.close()
raise ValueError("Unknown subheader signature")
return index

@@ -598,6 +613,7 @@ def _read_next_page(self):
if len(self._cached_page) <= 0:
return True
elif len(self._cached_page) != self._page_length:
self.close()
msg = ("failed to read complete page from file "
"(read {:d} of {:d} bytes)")
raise ValueError(msg.format(len(self._cached_page),
@@ -643,6 +659,7 @@ def _chunk_to_dataframe(self):
rslt.loc[ii, name] = np.nan
js += 1
else:
self.close()
raise ValueError("unknown column type %s" %
self.column_types[j])

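The pattern in SAS7BDATReader is uniform: every validation failure now calls self.close() before raising, and close() tolerates a handle that was never opened. A condensed sketch of that pattern, with simplified names that are not the actual pandas class:

class _BinaryReader(object):
    def __init__(self, path):
        self.handle = open(path, "rb")

    def close(self):
        try:
            self.handle.close()
        except AttributeError:
            # open() may have failed before self.handle was assigned
            pass

    def _read_bytes(self, offset, length):
        self.handle.seek(offset)
        buf = self.handle.read(length)
        if len(buf) < length:
            # release the file handle before surfacing the parse error
            self.close()
            raise ValueError("Unable to read {:d} bytes from file position "
                             "{:d}.".format(length, offset))
        return buf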
9 changes: 9 additions & 0 deletions pandas/io/sas/sas_xport.py
@@ -253,6 +253,9 @@ def __init__(self, filepath_or_buffer, index=None, encoding='ISO-8859-1',

self._read_header()

def close(self):
self.filepath_or_buffer.close()

def _get_row(self):
return self.filepath_or_buffer.read(80).decode()

@@ -262,13 +265,15 @@ def _read_header(self):
# read file header
line1 = self._get_row()
if line1 != _correct_line1:
self.close()
raise ValueError("Header record is not an XPORT file.")

line2 = self._get_row()
fif = [['prefix', 24], ['version', 8], ['OS', 8],
['_', 24], ['created', 16]]
file_info = _split_line(line2, fif)
if file_info['prefix'] != "SAS SAS SASLIB":
self.close()
raise ValueError("Header record has invalid prefix.")
file_info['created'] = _parse_date(file_info['created'])
self.file_info = file_info
@@ -282,6 +287,7 @@ def _read_header(self):
headflag1 = header1.startswith(_correct_header1)
headflag2 = (header2 == _correct_header2)
if not (headflag1 and headflag2):
self.close()
raise ValueError("Member header not found")
# usually 140, could be 135
fieldnamelength = int(header1[-5:-2])
@@ -321,6 +327,7 @@ def _read_header(self):
field['ntype'] = types[field['ntype']]
fl = field['field_length']
if field['ntype'] == 'numeric' and ((fl < 2) or (fl > 8)):
self.close()
msg = "Floating field width {0} is not between 2 and 8."
raise TypeError(msg.format(fl))

@@ -335,6 +342,7 @@ def _read_header(self):

header = self._get_row()
if not header == _correct_obs_header:
self.close()
raise ValueError("Observation header not found.")

self.fields = fields
@@ -425,6 +433,7 @@ def read(self, nrows=None):
read_lines = min(nrows, self.nobs - self._lines_read)
read_len = read_lines * self.record_length
if read_len <= 0:
self.close()
raise StopIteration
raw = self.filepath_or_buffer.read(read_len)
data = np.frombuffer(raw, dtype=self._dtype, count=read_lines)
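One consequence of the read() change above: when an incremental read runs off the end of the file, the reader closes itself before raising StopIteration, so a chunking loop finishes with the handle already released. A sketch, assuming a hypothetical demo.xpt:

from pandas.io.sas.sas_xport import XportReader

reader = XportReader("demo.xpt", chunksize=10)
nrows = 0
try:
    while True:
        nrows += len(reader.get_chunk())
except StopIteration:
    # read() called self.close() before raising, so the file is shut here
    pass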
4 changes: 3 additions & 1 deletion pandas/io/sas/sasreader.py
@@ -58,4 +58,6 @@ def read_sas(filepath_or_buffer, format=None, index=None, encoding=None,
if iterator or chunksize:
return reader

-    return reader.read()
+    data = reader.read()
+    reader.close()
+    return data
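
The change above is the top-level counterpart: read_sas only closes the reader on the eager path. A simplified sketch of the flow, where the helper name _make_reader is hypothetical:

def read_sas_sketch(filepath_or_buffer, format=None, iterator=False,
                    chunksize=None):
    reader = _make_reader(filepath_or_buffer, format)  # XportReader or SAS7BDATReader
    if iterator or chunksize:
        # lazy path: the caller receives the open reader and must close it
        return reader
    data = reader.read()
    reader.close()  # eager path: close before handing back the DataFrame
    return data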
1 change: 1 addition & 0 deletions pandas/io/tests/sas/test_sas7bdat.py
@@ -81,6 +81,7 @@ def test_encoding_options():
from pandas.io.sas.sas7bdat import SAS7BDATReader
rdr = SAS7BDATReader(fname, convert_header_text=False)
df3 = rdr.read()
rdr.close()
for x, y in zip(df1.columns, df3.columns):
assert(x == y.decode())

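The test now closes the reader it constructs directly, since constructing SAS7BDATReader yourself leaves closing to the caller. Outside a test, the same pattern reads more defensively with try/finally (file name hypothetical):

from pandas.io.sas.sas7bdat import SAS7BDATReader

rdr = SAS7BDATReader("test1.sas7bdat", convert_header_text=False)
try:
    df = rdr.read()
finally:
    rdr.close()  # direct reader construction leaves closing to the caller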
4 changes: 4 additions & 0 deletions pandas/io/tests/sas/test_xport.py
@@ -39,11 +39,13 @@ def test1_basic(self):
# Test incremental read with `read` method.
reader = read_sas(self.file01, format="xport", iterator=True)
data = reader.read(10)
reader.close()
tm.assert_frame_equal(data, data_csv.iloc[0:10, :])

# Test incremental read with `get_chunk` method.
reader = read_sas(self.file01, format="xport", chunksize=10)
data = reader.get_chunk()
reader.close()
tm.assert_frame_equal(data, data_csv.iloc[0:10, :])

# Read full file with `read_sas` method
@@ -66,13 +68,15 @@ def test1_index(self):
reader = read_sas(self.file01, index="SEQN", format="xport",
iterator=True)
data = reader.read(10)
reader.close()
tm.assert_frame_equal(data, data_csv.iloc[0:10, :],
check_index_type=False)

# Test incremental read with `get_chunk` method.
reader = read_sas(self.file01, index="SEQN", format="xport",
chunksize=10)
data = reader.get_chunk()
reader.close()
tm.assert_frame_equal(data, data_csv.iloc[0:10, :],
check_index_type=False)

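Not part of this commit, but one way to check that the fix holds on Python 3, where a file handle that is garbage-collected without being closed surfaces as a ResourceWarning (path hypothetical, and behavior depends on the interpreter's collector):

import gc
import warnings

import pandas as pd

with warnings.catch_warnings():
    warnings.simplefilter("error", ResourceWarning)
    pd.read_sas("demo.xpt", format="xport")
    gc.collect()  # a leaked handle would raise ResourceWarning here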
