Skip to content

Commit

Permalink
Merge pull request #778 from scottgigante/patch-1
Browse files: browse the repository at this point in the history
Avoid loading file into memory in read_csv
  • Loading branch information
jpmckinney committed Apr 27, 2024
2 parents 4001e97 + 24e8407 commit a1fe7e6
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 9 deletions.
11 changes: 10 additions & 1 deletion agate/table/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,13 @@ def __init__(self, rows, column_names=None, column_types=None, row_names=None, _
if column_names:
self._column_names = utils.deduplicate(column_names, column_names=True)
elif rows:
self._column_names = tuple(utils.letter_name(i) for i in range(len(rows[0])))
try:
first_row = rows[0]
except TypeError:
# it's an iterator
first_row = next(rows)
rows = chain([first_row], rows)
self._column_names = tuple(utils.letter_name(i) for i in range(len(first_row)))
warnings.warn('Column names not specified. "%s" will be used as names.' % str(self._column_names),
RuntimeWarning, stacklevel=2)
else:
Expand All @@ -102,6 +108,9 @@ def __init__(self, rows, column_names=None, column_types=None, row_names=None, _
raise ValueError('Column types must be instances of DataType.')

if isinstance(column_types, TypeTester):
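# If rows is a streaming iterator, materialize it in memory before running the type tester.
# NOTE: this could be improved -- presumably the tester does not need every row; confirm before changing.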
rows = tuple(rows)
self._column_types = column_types.run(rows, self._column_names)
else:
self._column_types = tuple(column_types)
Expand Down
20 changes: 12 additions & 8 deletions agate/table/from_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,14 +59,18 @@ def from_csv(cls, path, column_names=None, column_types=None, row_names=None, sk
else:
raise ValueError('skip_lines argument must be an int')

contents = StringIO(f.read())
handle = f

if sniff_limit is None:
kwargs['dialect'] = csv.Sniffer().sniff(contents.getvalue())
# avoid reading the file twice
handle = StringIO(f.read())
kwargs['dialect'] = csv.Sniffer().sniff(handle.getvalue())
elif sniff_limit > 0:
kwargs['dialect'] = csv.Sniffer().sniff(contents.getvalue()[:sniff_limit])
kwargs['dialect'] = csv.Sniffer().sniff(f.read(sniff_limit))
# return to the start of the file
f.seek(0)

reader = csv.reader(contents, header=header, **kwargs)
reader = csv.reader(handle, header=header, **kwargs)

if header:
if column_names is None:
Expand All @@ -75,12 +79,12 @@ def from_csv(cls, path, column_names=None, column_types=None, row_names=None, sk
next(reader)

if row_limit is None:
rows = tuple(reader)
rows = reader
else:
rows = tuple(itertools.islice(reader, row_limit))
rows = itertools.islice(reader, row_limit)

return Table(rows, column_names, column_types, row_names=row_names)

finally:
if close:
f.close()

return Table(rows, column_names, column_types, row_names=row_names)

0 comments on commit a1fe7e6

Please sign in to comment.