Skip to content

Commit

Permalink
ENH: Allow timestamp and data label to be set when exporting to Stata
Browse files Browse the repository at this point in the history
Added code which allows the time stamp and the data label to be set using
either StataWriter or to_stata.  Also simplified reading these values using
StataReader by removing null bytes from the string values read.

Added basic test for both.

Also fixed one small bug where variables could be stored using Stata reserved
words.
  • Loading branch information
bashtage committed Mar 6, 2014
1 parent 170377d commit 0638be8
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 17 deletions.
1 change: 1 addition & 0 deletions doc/source/release.rst
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ Improvements to existing features
- perf improvements in DataFrame construction with certain offsets, by removing faulty caching
(e.g. MonthEnd,BusinessMonthEnd), (:issue:`6479`)
- perf improvements in single-dtyped indexing (:issue:`6484`)
- ``StataWriter`` and ``DataFrame.to_stata`` accept time stamp and data labels (:issue:`6545`)

.. _release.bug_fixes-0.14.0:

Expand Down
3 changes: 3 additions & 0 deletions doc/source/v0.14.0.txt
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,9 @@ Enhancements
- ``DataFrame.to_stata`` will now check data for compatibility with Stata data types
and will upcast when needed. When it isn't possible to losslessly upcast, a warning
is raised (:issue:`6327`)
- ``DataFrame.to_stata`` and ``StataWriter`` will accept keyword arguments ``time_stamp``
and ``data_label``, which allow the time stamp and dataset label to be set when creating a
file. (:issue:`6545`)

Performance
~~~~~~~~~~~
Expand Down
5 changes: 3 additions & 2 deletions pandas/core/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -1216,7 +1216,7 @@ def to_excel(self, excel_writer, sheet_name='Sheet1', na_rep='',

def to_stata(
self, fname, convert_dates=None, write_index=True, encoding="latin-1",
byteorder=None):
byteorder=None, time_stamp=None, data_label=None):
"""
Write this DataFrame to a Stata binary dta file
Expand Down Expand Up @@ -1247,7 +1247,8 @@ def to_stata(
"""
from pandas.io.stata import StataWriter
writer = StataWriter(fname, self, convert_dates=convert_dates,
encoding=encoding, byteorder=byteorder)
encoding=encoding, byteorder=byteorder,
time_stamp=time_stamp, data_label=data_label)
writer.write_file()

def to_sql(self, name, con, flavor='sqlite', if_exists='fail', **kwargs):
Expand Down
43 changes: 33 additions & 10 deletions pandas/io/stata.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,6 +375,18 @@ def __init__(self, encoding):
'd': np.float64(struct.unpack('<d', b'\x00\x00\x00\x00\x00\x00\xe0\x7f')[0])
}

# Reserved words cannot be used as variable names
self.RESERVED_WORDS = ('aggregate', 'array', 'boolean', 'break',
'byte', 'case', 'catch', 'class', 'colvector',
'complex', 'const', 'continue', 'default',
'delegate', 'delete', 'do', 'double', 'else',
'eltypedef', 'end', 'enum', 'explicit',
'export', 'external', 'float', 'for', 'friend',
'function', 'global', 'goto', 'if', 'inline',
'int', 'local', 'long', 'NULL', 'pragma',
'protected', 'quad', 'rowvector', 'short',
'typedef', 'typename', 'virtual')

def _decode_bytes(self, str, errors=None):
if compat.PY3 or self._encoding is not None:
return str.decode(self._encoding, errors)
Expand Down Expand Up @@ -449,10 +461,10 @@ def _read_header(self):
self.path_or_buf.read(4))[0]
self.path_or_buf.read(11) # </N><label>
strlen = struct.unpack('b', self.path_or_buf.read(1))[0]
self.data_label = self.path_or_buf.read(strlen)
self.data_label = self._null_terminate(self.path_or_buf.read(strlen))
self.path_or_buf.read(19) # </label><timestamp>
strlen = struct.unpack('b', self.path_or_buf.read(1))[0]
self.time_stamp = self.path_or_buf.read(strlen)
self.time_stamp = self._null_terminate(self.path_or_buf.read(strlen))
self.path_or_buf.read(26) # </timestamp></header><map>
self.path_or_buf.read(8) # 0x0000000000000000
self.path_or_buf.read(8) # position of <map>
Expand Down Expand Up @@ -543,11 +555,11 @@ def _read_header(self):
self.nobs = struct.unpack(self.byteorder + 'I',
self.path_or_buf.read(4))[0]
if self.format_version > 105:
self.data_label = self.path_or_buf.read(81)
self.data_label = self._null_terminate(self.path_or_buf.read(81))
else:
self.data_label = self.path_or_buf.read(32)
self.data_label = self._null_terminate(self.path_or_buf.read(32))
if self.format_version > 104:
self.time_stamp = self.path_or_buf.read(18)
self.time_stamp = self._null_terminate(self.path_or_buf.read(18))

# descriptors
if self.format_version > 108:
Expand Down Expand Up @@ -1029,6 +1041,11 @@ class StataWriter(StataParser):
byteorder : str
Can be ">", "<", "little", or "big". The default is None which uses
`sys.byteorder`
time_stamp : datetime
A date time to use when writing the file. Can be None, in which
case the current time is used.
data_label : str
A label for the data set. Should be 80 characters or fewer.
Returns
-------
Expand All @@ -1047,10 +1064,13 @@ class StataWriter(StataParser):
>>> writer.write_file()
"""
def __init__(self, fname, data, convert_dates=None, write_index=True,
encoding="latin-1", byteorder=None):
encoding="latin-1", byteorder=None, time_stamp=None,
data_label=None):
super(StataWriter, self).__init__(encoding)
self._convert_dates = convert_dates
self._write_index = write_index
self._time_stamp = time_stamp
self._data_label = data_label
# attach nobs, nvars, data, varlist, typlist
self._prepare_pandas(data)

Expand Down Expand Up @@ -1086,7 +1106,7 @@ def __iter__(self):

if self._write_index:
data = data.reset_index()
# Check columns for compatbaility with stata
# Check columns for compatibility with stata
data = _cast_to_stata_types(data)
self.datarows = DataFrameRowIter(data)
self.nobs, self.nvar = data.shape
Expand All @@ -1110,7 +1130,8 @@ def __iter__(self):
self.fmtlist[key] = self._convert_dates[key]

def write_file(self):
self._write_header()
self._write_header(time_stamp=self._time_stamp,
data_label=self._data_label)
self._write_descriptors()
self._write_variable_labels()
# write 5 zeros for expansion fields
Expand Down Expand Up @@ -1147,7 +1168,7 @@ def _write_header(self, data_label=None, time_stamp=None):
# format dd Mon yyyy hh:mm
if time_stamp is None:
time_stamp = datetime.datetime.now()
elif not isinstance(time_stamp, datetime):
elif not isinstance(time_stamp, datetime.datetime):
raise ValueError("time_stamp should be datetime type")
self._file.write(
self._null_terminate(time_stamp.strftime("%d %b %Y %H:%M"))
Expand All @@ -1169,7 +1190,9 @@ def _write_descriptors(self, typlist=None, varlist=None, srtlist=None,
for c in name:
if (c < 'A' or c > 'Z') and (c < 'a' or c > 'z') and (c < '0' or c > '9') and c != '_':
name = name.replace(c, '_')

# Variable name must not be a reserved word
if name in self.RESERVED_WORDS:
name = '_' + name
# Variable name may not start with a number
if name[0] > '0' and name[0] < '9':
name = '_' + name
Expand Down
32 changes: 27 additions & 5 deletions pandas/io/tests/test_stata.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# pylint: disable=E1101

from datetime import datetime
import datetime as dt
import os
import warnings
import nose
Expand Down Expand Up @@ -248,7 +249,7 @@ def test_read_write_dta10(self):

original = DataFrame(data=[["string", "object", 1, 1.1,
np.datetime64('2003-12-25')]],
columns=['string', 'object', 'integer', 'float',
columns=['string', 'object', 'integer', 'floating',
'datetime'])
original["object"] = Series(original["object"], dtype=object)
original.index.name = 'index'
Expand Down Expand Up @@ -304,10 +305,20 @@ def test_read_write_dta11(self):
def test_read_write_dta12(self):
# skip_if_not_little_endian()

original = DataFrame([(1, 2, 3, 4)],
columns=['astringwithmorethan32characters_1', 'astringwithmorethan32characters_2', '+', '-'])
formatted = DataFrame([(1, 2, 3, 4)],
columns=['astringwithmorethan32characters_', '_0astringwithmorethan32character', '_', '_1_'])
original = DataFrame([(1, 2, 3, 4, 5, 6)],
columns=['astringwithmorethan32characters_1',
'astringwithmorethan32characters_2',
'+',
'-',
'short',
'delete'])
formatted = DataFrame([(1, 2, 3, 4, 5, 6)],
columns=['astringwithmorethan32characters_',
'_0astringwithmorethan32character',
'_',
'_1_',
'_short',
'_delete'])
formatted.index.name = 'index'
formatted = formatted.astype(np.int32)

Expand Down Expand Up @@ -376,6 +387,17 @@ def test_read_write_reread_dta15(self):
tm.assert_frame_equal(parsed_113, parsed_114)
tm.assert_frame_equal(parsed_114, parsed_115)

def test_timestamp_and_label(self):
# Round-trip check: export a one-cell frame with an explicit time stamp and
# data label through DataFrame.to_stata, then read both values back from a
# StataReader opened on the same path.
original = DataFrame([(1,)], columns=['var'])
time_stamp = datetime(2000, 2, 29, 14, 21)
data_label = 'This is a data file.'
# presumably tm.ensure_clean() yields a temporary file path - confirm against
# pandas.util.testing
with tm.ensure_clean() as path:
original.to_stata(path, time_stamp=time_stamp, data_label=data_label)
reader = StataReader(path)
# The reader exposes the time stamp as text; parse it with the
# 'dd Mon yyyy hh:mm' layout so it can be compared to the datetime passed in.
parsed_time_stamp = dt.datetime.strptime(reader.time_stamp, ('%d %b %Y %H:%M'))
assert parsed_time_stamp == time_stamp
# The label read back must equal the label passed to to_stata.
assert reader.data_label == data_label



if __name__ == '__main__':
Expand Down

0 comments on commit 0638be8

Please sign in to comment.