Skip to content

Commit

Permalink
Merge pull request #6553 from bashtage/time_stamp-data_label-reserved…
Browse files Browse the repository at this point in the history
…_words

ENH: Allow timestamp and data label to be set when exporting to Stata
  • Loading branch information
jreback committed Mar 7, 2014
2 parents 3590d8c + 0638be8 commit 7d49037
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 17 deletions.
1 change: 1 addition & 0 deletions doc/source/release.rst
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ Improvements to existing features
- perf improvements in DataFrame construction with certain offsets, by removing faulty caching
(e.g. MonthEnd,BusinessMonthEnd), (:issue:`6479`)
- perf improvements in single-dtyped indexing (:issue:`6484`)
- ``StataWriter`` and ``DataFrame.to_stata`` accept time stamp and data labels (:issue:`6545`)

.. _release.bug_fixes-0.14.0:

Expand Down
3 changes: 3 additions & 0 deletions doc/source/v0.14.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,9 @@ Enhancements
- ``DataFrame.to_stata`` will now check data for compatibility with Stata data types
and will upcast when needed. When it isn't possible to losslessly upcast, a warning
is raised (:issue:`6327`)
- ``DataFrame.to_stata`` and ``StataWriter`` will accept keyword arguments ``time_stamp``
and ``data_label``, which allow the time stamp and dataset label to be set when creating a
file. (:issue:`6545`)

Performance
~~~~~~~~~~~
Expand Down
5 changes: 3 additions & 2 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1216,7 +1216,7 @@ def to_excel(self, excel_writer, sheet_name='Sheet1', na_rep='',

def to_stata(
self, fname, convert_dates=None, write_index=True, encoding="latin-1",
byteorder=None):
byteorder=None, time_stamp=None, data_label=None):
"""
A class for writing Stata binary dta files from array-like objects
Expand Down Expand Up @@ -1247,7 +1247,8 @@ def to_stata(
"""
from pandas.io.stata import StataWriter
writer = StataWriter(fname, self, convert_dates=convert_dates,
encoding=encoding, byteorder=byteorder)
encoding=encoding, byteorder=byteorder,
time_stamp=time_stamp, data_label=data_label)
writer.write_file()

def to_sql(self, name, con, flavor='sqlite', if_exists='fail', **kwargs):
Expand Down
43 changes: 33 additions & 10 deletions pandas/io/stata.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,18 @@ def __init__(self, encoding):
'd': np.float64(struct.unpack('<d', b'\x00\x00\x00\x00\x00\x00\xe0\x7f')[0])
}

# Reserved words cannot be used as variable names
self.RESERVED_WORDS = ('aggregate', 'array', 'boolean', 'break',
'byte', 'case', 'catch', 'class', 'colvector',
'complex', 'const', 'continue', 'default',
'delegate', 'delete', 'do', 'double', 'else',
'eltypedef', 'end', 'enum', 'explicit',
'export', 'external', 'float', 'for', 'friend',
'function', 'global', 'goto', 'if', 'inline',
'int', 'local', 'long', 'NULL', 'pragma',
'protected', 'quad', 'rowvector', 'short',
'typedef', 'typename', 'virtual')

def _decode_bytes(self, str, errors=None):
if compat.PY3 or self._encoding is not None:
return str.decode(self._encoding, errors)
Expand Down Expand Up @@ -449,10 +461,10 @@ def _read_header(self):
self.path_or_buf.read(4))[0]
self.path_or_buf.read(11) # </N><label>
strlen = struct.unpack('b', self.path_or_buf.read(1))[0]
self.data_label = self.path_or_buf.read(strlen)
self.data_label = self._null_terminate(self.path_or_buf.read(strlen))
self.path_or_buf.read(19) # </label><timestamp>
strlen = struct.unpack('b', self.path_or_buf.read(1))[0]
self.time_stamp = self.path_or_buf.read(strlen)
self.time_stamp = self._null_terminate(self.path_or_buf.read(strlen))
self.path_or_buf.read(26) # </timestamp></header><map>
self.path_or_buf.read(8) # 0x0000000000000000
self.path_or_buf.read(8) # position of <map>
Expand Down Expand Up @@ -543,11 +555,11 @@ def _read_header(self):
self.nobs = struct.unpack(self.byteorder + 'I',
self.path_or_buf.read(4))[0]
if self.format_version > 105:
self.data_label = self.path_or_buf.read(81)
self.data_label = self._null_terminate(self.path_or_buf.read(81))
else:
self.data_label = self.path_or_buf.read(32)
self.data_label = self._null_terminate(self.path_or_buf.read(32))
if self.format_version > 104:
self.time_stamp = self.path_or_buf.read(18)
self.time_stamp = self._null_terminate(self.path_or_buf.read(18))

# descriptors
if self.format_version > 108:
Expand Down Expand Up @@ -1029,6 +1041,11 @@ class StataWriter(StataParser):
byteorder : str
Can be ">", "<", "little", or "big". The default is None which uses
`sys.byteorder`
time_stamp : datetime
A date time to use when writing the file. Can be None, in which
case the current time is used.
data_label : str
A label for the data set. Should be 80 characters or fewer.
Returns
-------
Expand All @@ -1047,10 +1064,13 @@ class StataWriter(StataParser):
>>> writer.write_file()
"""
def __init__(self, fname, data, convert_dates=None, write_index=True,
encoding="latin-1", byteorder=None):
encoding="latin-1", byteorder=None, time_stamp=None,
data_label=None):
super(StataWriter, self).__init__(encoding)
self._convert_dates = convert_dates
self._write_index = write_index
self._time_stamp = time_stamp
self._data_label = data_label
# attach nobs, nvars, data, varlist, typlist
self._prepare_pandas(data)

Expand Down Expand Up @@ -1086,7 +1106,7 @@ def __iter__(self):

if self._write_index:
data = data.reset_index()
# Check columns for compatbaility with stata
# Check columns for compatibility with stata
data = _cast_to_stata_types(data)
self.datarows = DataFrameRowIter(data)
self.nobs, self.nvar = data.shape
Expand All @@ -1110,7 +1130,8 @@ def __iter__(self):
self.fmtlist[key] = self._convert_dates[key]

def write_file(self):
self._write_header()
self._write_header(time_stamp=self._time_stamp,
data_label=self._data_label)
self._write_descriptors()
self._write_variable_labels()
# write 5 zeros for expansion fields
Expand Down Expand Up @@ -1147,7 +1168,7 @@ def _write_header(self, data_label=None, time_stamp=None):
# format dd Mon yyyy hh:mm
if time_stamp is None:
time_stamp = datetime.datetime.now()
elif not isinstance(time_stamp, datetime):
elif not isinstance(time_stamp, datetime.datetime):
raise ValueError("time_stamp should be datetime type")
self._file.write(
self._null_terminate(time_stamp.strftime("%d %b %Y %H:%M"))
Expand All @@ -1169,7 +1190,9 @@ def _write_descriptors(self, typlist=None, varlist=None, srtlist=None,
for c in name:
if (c < 'A' or c > 'Z') and (c < 'a' or c > 'z') and (c < '0' or c > '9') and c != '_':
name = name.replace(c, '_')

# Variable name must not be a reserved word
if name in self.RESERVED_WORDS:
name = '_' + name
# Variable name may not start with a number
if name[0] > '0' and name[0] < '9':
name = '_' + name
Expand Down
32 changes: 27 additions & 5 deletions pandas/io/tests/test_stata.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# pylint: disable=E1101

from datetime import datetime
import datetime as dt
import os
import warnings
import nose
Expand Down Expand Up @@ -248,7 +249,7 @@ def test_read_write_dta10(self):

original = DataFrame(data=[["string", "object", 1, 1.1,
np.datetime64('2003-12-25')]],
columns=['string', 'object', 'integer', 'float',
columns=['string', 'object', 'integer', 'floating',
'datetime'])
original["object"] = Series(original["object"], dtype=object)
original.index.name = 'index'
Expand Down Expand Up @@ -304,10 +305,20 @@ def test_read_write_dta11(self):
def test_read_write_dta12(self):
# skip_if_not_little_endian()

original = DataFrame([(1, 2, 3, 4)],
columns=['astringwithmorethan32characters_1', 'astringwithmorethan32characters_2', '+', '-'])
formatted = DataFrame([(1, 2, 3, 4)],
columns=['astringwithmorethan32characters_', '_0astringwithmorethan32character', '_', '_1_'])
original = DataFrame([(1, 2, 3, 4, 5, 6)],
columns=['astringwithmorethan32characters_1',
'astringwithmorethan32characters_2',
'+',
'-',
'short',
'delete'])
formatted = DataFrame([(1, 2, 3, 4, 5, 6)],
columns=['astringwithmorethan32characters_',
'_0astringwithmorethan32character',
'_',
'_1_',
'_short',
'_delete'])
formatted.index.name = 'index'
formatted = formatted.astype(np.int32)

Expand Down Expand Up @@ -376,6 +387,17 @@ def test_read_write_reread_dta15(self):
tm.assert_frame_equal(parsed_113, parsed_114)
tm.assert_frame_equal(parsed_114, parsed_115)

def test_timestamp_and_label(self):
    """Check that a custom time stamp and data label survive a write/read cycle."""
    expected_label = 'This is a data file.'
    expected_stamp = datetime(2000, 2, 29, 14, 21)
    frame = DataFrame([(1,)], columns=['var'])
    with tm.ensure_clean() as path:
        frame.to_stata(path, time_stamp=expected_stamp,
                       data_label=expected_label)
        reader = StataReader(path)
        # The reader exposes the stamp as text, so parse it back into a
        # datetime before comparing with the value that was written.
        stamp_format = '%d %b %Y %H:%M'
        round_tripped = dt.datetime.strptime(reader.time_stamp, stamp_format)
        assert round_tripped == expected_stamp
        assert reader.data_label == expected_label



if __name__ == '__main__':
Expand Down

0 comments on commit 7d49037

Please sign in to comment.