diff --git a/agate/table/__init__.py b/agate/table/__init__.py index 51afca46..22feef03 100644 --- a/agate/table/__init__.py +++ b/agate/table/__init__.py @@ -79,7 +79,13 @@ def __init__(self, rows, column_names=None, column_types=None, row_names=None, _ if column_names: self._column_names = utils.deduplicate(column_names, column_names=True) elif rows: - self._column_names = tuple(utils.letter_name(i) for i in range(len(rows[0]))) + try: + first_row = rows[0] + except TypeError: + # it's an iterator + first_row = next(rows) + rows = chain([first_row], rows) + self._column_names = tuple(utils.letter_name(i) for i in range(len(first_row))) warnings.warn('Column names not specified. "%s" will be used as names.' % str(self._column_names), RuntimeWarning, stacklevel=2) else: @@ -102,6 +108,9 @@ def __init__(self, rows, column_names=None, column_types=None, row_names=None, _ raise ValueError('Column types must be instances of DataType.') if isinstance(column_types, TypeTester): + # if rows is streaming, we must bring it in-memory + # note: we could do better here. + rows = tuple(rows) self._column_types = column_types.run(rows, self._column_names) else: self._column_types = tuple(column_types) diff --git a/agate/table/from_csv.py b/agate/table/from_csv.py index 3b1ac074..4dd9b6e3 100644 --- a/agate/table/from_csv.py +++ b/agate/table/from_csv.py @@ -59,14 +59,18 @@ def from_csv(cls, path, column_names=None, column_types=None, row_names=None, sk else: raise ValueError('skip_lines argument must be an int') - contents = StringIO(f.read()) + handle = f if sniff_limit is None: - kwargs['dialect'] = csv.Sniffer().sniff(contents.getvalue()) + # avoid reading the file twice + handle = StringIO(f.read()) + kwargs['dialect'] = csv.Sniffer().sniff(handle.getvalue()) elif sniff_limit > 0: - kwargs['dialect'] = csv.Sniffer().sniff(contents.getvalue()[:sniff_limit]) + kwargs['dialect'] = csv.Sniffer().sniff(f.read(sniff_limit)) + # return to the start of the file + f.seek(0) - reader = csv.reader(contents, header=header, **kwargs) + reader = csv.reader(handle, header=header, **kwargs) if header: if column_names is None: @@ -75,12 +79,12 @@ def from_csv(cls, path, column_names=None, column_types=None, row_names=None, sk next(reader) if row_limit is None: - rows = tuple(reader) + rows = reader else: - rows = tuple(itertools.islice(reader, row_limit)) + rows = itertools.islice(reader, row_limit) + + return Table(rows, column_names, column_types, row_names=row_names) finally: if close: f.close() - - return Table(rows, column_names, column_types, row_names=row_names)