diff --git a/c_glib/README.md b/c_glib/README.md index 1f67d7ea52ccb..d801fc83b5324 100644 --- a/c_glib/README.md +++ b/c_glib/README.md @@ -53,6 +53,8 @@ recommended that you use packages. Note that the packages are "unofficial". "Official" packages will be released in the future. +If you find problems when installing, please see [common build problems](https://github.com/apache/arrow/blob/master/c_glib/README.md#common-build-problems). + ### Package See [install document](../site/install.md) for details. diff --git a/c_glib/configure.ac b/c_glib/configure.ac index eabe7bad51227..f4f2c99bbc39e 100644 --- a/c_glib/configure.ac +++ b/c_glib/configure.ac @@ -143,7 +143,7 @@ AC_CONFIG_FILES([ arrow-gpu-glib/arrow-gpu-glib.pc doc/Makefile doc/reference/Makefile - doc/reference/xml/Makefile + doc/reference/entities.xml example/Makefile example/lua/Makefile tool/Makefile diff --git a/c_glib/doc/reference/Makefile.am b/c_glib/doc/reference/Makefile.am index 4c005c237b300..454c2b0692da6 100644 --- a/c_glib/doc/reference/Makefile.am +++ b/c_glib/doc/reference/Makefile.am @@ -15,9 +15,6 @@ # specific language governing permissions and limitations # under the License. -SUBDIRS = \ - xml - DOC_MODULE = arrow-glib DOC_MAIN_SGML_FILE = $(DOC_MODULE)-docs.xml @@ -72,4 +69,5 @@ CLEANFILES += \ $(DOC_MODULE).types EXTRA_DIST += \ + entities.xml.in \ meson.build diff --git a/c_glib/doc/reference/arrow-glib-docs.xml b/c_glib/doc/reference/arrow-glib-docs.xml index 51e7b2a6a6cf5..23d1e9a0f271a 100644 --- a/c_glib/doc/reference/arrow-glib-docs.xml +++ b/c_glib/doc/reference/arrow-glib-docs.xml @@ -21,10 +21,10 @@ "http://www.oasis-open.org/docbook/xml/4.3/docbookx.dtd" [ - + %gtkdocentities; ]> - + &package_name; Reference Manual diff --git a/c_glib/doc/reference/xml/gtkdocentities.ent.in b/c_glib/doc/reference/entities.xml.in similarity index 76% rename from c_glib/doc/reference/xml/gtkdocentities.ent.in rename to c_glib/doc/reference/entities.xml.in index dc0cf1a0d8d4a..aa5addb4e8431 100644 --- a/c_glib/doc/reference/xml/gtkdocentities.ent.in +++ b/c_glib/doc/reference/entities.xml.in @@ -16,9 +16,9 @@ specific language governing permissions and limitations under the License. --> - - - - - - + + + + + + diff --git a/c_glib/doc/reference/meson.build b/c_glib/doc/reference/meson.build index 3374fbde5b9ed..431aa0a5c82a1 100644 --- a/c_glib/doc/reference/meson.build +++ b/c_glib/doc/reference/meson.build @@ -17,7 +17,18 @@ # specific language governing permissions and limitations # under the License.
-subdir('xml') +entities_conf = configuration_data() +entities_conf.set('PACKAGE', meson.project_name()) +entities_conf.set('PACKAGE_BUGREPORT', + 'https://issues.apache.org/jira/browse/ARROW') +entities_conf.set('PACKAGE_NAME', meson.project_name()) +entities_conf.set('PACKAGE_STRING', + ' '.join([meson.project_name(), version])) +entities_conf.set('PACKAGE_URL', 'https://arrow.apache.org/') +entities_conf.set('PACKAGE_VERSION', version) +configure_file(input: 'entities.xml.in', + output: 'entities.xml', + configuration: entities_conf) private_headers = [ ] diff --git a/ci/msvc-build.bat b/ci/msvc-build.bat index 62ebcf364e77b..58dfc2a146572 100644 --- a/ci/msvc-build.bat +++ b/ci/msvc-build.bat @@ -81,7 +81,7 @@ conda info -a conda create -n arrow -q -y python=%PYTHON% ^ six pytest setuptools numpy pandas cython ^ - thrift-cpp + thrift-cpp=0.11.0 if "%JOB%" == "Toolchain" ( @@ -145,6 +145,6 @@ pushd python set PYARROW_CXXFLAGS=/WX python setup.py build_ext --inplace --with-parquet --bundle-arrow-cpp bdist_wheel || exit /B -py.test pyarrow -v -s --parquet || exit /B +py.test pyarrow -r sxX --durations=15 -v -s --parquet || exit /B popd diff --git a/ci/travis_before_script_cpp.sh b/ci/travis_before_script_cpp.sh index fd2c1644638c4..7c1d726d4d37e 100755 --- a/ci/travis_before_script_cpp.sh +++ b/ci/travis_before_script_cpp.sh @@ -47,7 +47,7 @@ if [ "$ARROW_TRAVIS_USE_TOOLCHAIN" == "1" ]; then zlib \ cmake \ curl \ - thrift-cpp \ + thrift-cpp=0.11.0 \ ninja # HACK(wesm): We started experiencing OpenSSL failures when Miniconda was diff --git a/ci/travis_lint.sh b/ci/travis_lint.sh index e234b7b015b8d..6a2a0be18cf9f 100755 --- a/ci/travis_lint.sh +++ b/ci/travis_lint.sh @@ -35,10 +35,10 @@ popd # Fail fast on style checks sudo pip install flake8 -PYARROW_DIR=$TRAVIS_BUILD_DIR/python/pyarrow +PYTHON_DIR=$TRAVIS_BUILD_DIR/python -flake8 --count $PYARROW_DIR +flake8 --count $PYTHON_DIR/pyarrow # Check Cython files with some checks turned off flake8 --count --config=$PYTHON_DIR/.flake8.cython \ - $PYARROW_DIR + $PYTHON_DIR/pyarrow diff --git a/ci/travis_script_python.sh b/ci/travis_script_python.sh index 9e74906d03739..7c896df9c840f 100755 --- a/ci/travis_script_python.sh +++ b/ci/travis_script_python.sh @@ -96,7 +96,7 @@ if [ $TRAVIS_OS_NAME == "linux" ]; then fi PYARROW_PATH=$CONDA_PREFIX/lib/python$PYTHON_VERSION/site-packages/pyarrow -python -m pytest -vv -r sxX -s $PYARROW_PATH --parquet +python -m pytest -vv -r sxX --durations=15 -s $PYARROW_PATH --parquet if [ "$PYTHON_VERSION" == "3.6" ] && [ $TRAVIS_OS_NAME == "linux" ]; then # Build documentation once diff --git a/cpp/README.md b/cpp/README.md index 39a1ccac64818..52169974de41e 100644 --- a/cpp/README.md +++ b/cpp/README.md @@ -39,9 +39,11 @@ sudo apt-get install cmake \ libboost-system-dev ``` -On OS X, you can use [Homebrew][1]: +On macOS, you can use [Homebrew][1]: ```shell +git clone https://github.com/apache/arrow.git +cd arrow brew update && brew bundle --file=c_glib/Brewfile ``` @@ -250,6 +252,21 @@ Logging IWYU to /tmp/arrow-cpp-iwyu.gT7XXV ... ``` +### Linting + +We require that you follow a certain coding style in the C++ code base. +You can check that your code abides by that coding style by running: + + make lint + +You can also fix any formatting errors automatically: + + make format + +These commands require `clang-format-4.0` (and not any other version). +You may find the required packages at http://releases.llvm.org/download.html +or use the Debian/Ubuntu APT repositories at https://apt.llvm.org/.
+ ## Continuous Integration Pull requests are run through travis-ci for continuous integration. You can avoid diff --git a/cpp/apidoc/HDFS.md b/cpp/apidoc/HDFS.md index d54ad270c05f4..d3671fb7691ba 100644 --- a/cpp/apidoc/HDFS.md +++ b/cpp/apidoc/HDFS.md @@ -50,6 +50,10 @@ export CLASSPATH=`$HADOOP_HOME/bin/hadoop classpath --glob` * `ARROW_LIBHDFS_DIR` (optional): explicit location of `libhdfs.so` if it is installed somewhere other than `$HADOOP_HOME/lib/native`. +To accommodate distribution-specific nuances, the `JAVA_HOME` variable may be +set to the root path for the Java SDK, the JRE path itself, or to the directory +containing the `libjvm` library. + ### Mac Specifics The installed location of Java on OS X can vary, however the following snippet diff --git a/cpp/build-support/cpplint.py b/cpp/build-support/cpplint.py index ccc25d4c56b1a..95c0c32595d81 100755 --- a/cpp/build-support/cpplint.py +++ b/cpp/build-support/cpplint.py @@ -44,6 +44,8 @@ import codecs import copy import getopt +import glob +import itertools import math # for log import os import re @@ -51,16 +53,47 @@ import string import sys import unicodedata +import xml.etree.ElementTree + +# if empty, use defaults +_header_extensions = set([]) + +# if empty, use defaults +_valid_extensions = set([]) + + +# Files with any of these extensions are considered to be +# header files (and will undergo different style checks). +# This set can be extended by using the --headers +# option (also supported in CPPLINT.cfg) +def GetHeaderExtensions(): + if not _header_extensions: + return set(['h', 'hpp', 'hxx', 'h++', 'cuh']) + return _header_extensions + +# The allowed extensions for file names +# This is set by --extensions flag +def GetAllExtensions(): + if not _valid_extensions: + return GetHeaderExtensions().union(set(['c', 'cc', 'cpp', 'cxx', 'c++', 'cu'])) + return _valid_extensions + +def GetNonHeaderExtensions(): + return GetAllExtensions().difference(GetHeaderExtensions()) _USAGE = """ -Syntax: cpplint.py [--verbose=#] [--output=vs7] [--filter=-x,+y,...] - [--counting=total|toplevel|detailed] [--root=subdir] - [--linelength=digits] +Syntax: cpplint.py [--verbose=#] [--output=emacs|eclipse|vs7|junit] + [--filter=-x,+y,...] + [--counting=total|toplevel|detailed] [--repository=path] + [--root=subdir] [--linelength=digits] [--recursive] + [--exclude=path] + [--headers=ext1,ext2] + [--extensions=hpp,cpp,...] [file] ... The style guidelines this tries to follow are those in - http://google-styleguide.googlecode.com/svn/trunk/cppguide.xml + https://google.github.io/styleguide/cppguide.html Every problem is given a confidence score from 1-5, with 5 meaning we are certain of the problem, and 1 meaning it could be a legitimate construct. @@ -71,17 +104,26 @@ suppresses errors of all categories on that line. The files passed in will be linted; at least one file must be provided. - Default linted extensions are .cc, .cpp, .cu, .cuh and .h. Change the - extensions with the --extensions flag. + Default linted extensions are %s. + Other file types will be ignored. + Change the extensions with the --extensions flag. Flags: - output=vs7 - By default, the output is formatted to ease emacs parsing. Visual Studio - compatible output (vs7) may also be used. Other formats are unsupported. + output=emacs|eclipse|vs7|junit + By default, the output is formatted to ease emacs parsing. Output + compatible with eclipse (eclipse), Visual Studio (vs7), and JUnit + XML parsers such as those used in Jenkins and Bamboo may also be + used. 
Other formats are unsupported. verbose=# Specify a number 0-5 to restrict errors to certain verbosity levels. + Errors with lower verbosity levels have lower confidence and are more + likely to be false positives. + + quiet + Supress output other than linting errors, such as information about + which files have been processed and excluded. filter=-x,+y,... Specify a comma-separated list of category-filters to apply: only @@ -105,17 +147,40 @@ also be printed. If 'detailed' is provided, then a count is provided for each category like 'build/class'. + repository=path + The top level directory of the repository, used to derive the header + guard CPP variable. By default, this is determined by searching for a + path that contains .git, .hg, or .svn. When this flag is specified, the + given path is used instead. This option allows the header guard CPP + variable to remain consistent even if members of a team have different + repository root directories (such as when checking out a subdirectory + with SVN). In addition, users of non-mainstream version control systems + can use this flag to ensure readable header guard CPP variables. + + Examples: + Assuming that Alice checks out ProjectName and Bob checks out + ProjectName/trunk and trunk contains src/chrome/ui/browser.h, then + with no --repository flag, the header guard CPP variable will be: + + Alice => TRUNK_SRC_CHROME_BROWSER_UI_BROWSER_H_ + Bob => SRC_CHROME_BROWSER_UI_BROWSER_H_ + + If Alice uses the --repository=trunk flag and Bob omits the flag or + uses --repository=. then the header guard CPP variable will be: + + Alice => SRC_CHROME_BROWSER_UI_BROWSER_H_ + Bob => SRC_CHROME_BROWSER_UI_BROWSER_H_ + root=subdir - The root directory used for deriving header guard CPP variable. - By default, the header guard CPP variable is calculated as the relative - path to the directory that contains .git, .hg, or .svn. When this flag - is specified, the relative path is calculated from the specified - directory. If the specified directory does not exist, this flag is - ignored. + The root directory used for deriving header guard CPP variables. This + directory is relative to the top level directory of the repository which + by default is determined by searching for a directory that contains .git, + .hg, or .svn but can also be controlled with the --repository flag. If + the specified directory does not exist, this flag is ignored. Examples: - Assuming that src/.git exists, the header guard CPP variables for - src/chrome/browser/ui/browser.h are: + Assuming that src is the top level directory of the repository, the + header guard CPP variables for src/chrome/browser/ui/browser.h are: No flag => CHROME_BROWSER_UI_BROWSER_H_ --root=chrome => BROWSER_UI_BROWSER_H_ @@ -128,11 +193,36 @@ Examples: --linelength=120 + recursive + Search for files to lint recursively. Each directory given in the list + of files to be linted is replaced by all files that descend from that + directory. Files with extensions not in the valid extensions list are + excluded. + + exclude=path + Exclude the given path from the list of files to be linted. Relative + paths are evaluated relative to the current directory and shell globbing + is performed. This flag can be provided multiple times to exclude + multiple files. + + Examples: + --exclude=one.cc + --exclude=src/*.cc + --exclude=src/*.cc --exclude=test/*.cc + extensions=extension,extension,... 
The allowed file extensions that cpplint will check Examples: - --extensions=hpp,cpp + --extensions=%s + + headers=extension,extension,... + The allowed header extensions that cpplint will consider to be header files + (by default, only files with extensions %s + will be assumed to be headers) + + Examples: + --headers=%s cpplint.py supports per-directory configurations specified in CPPLINT.cfg files. CPPLINT.cfg file can contain a number of key=value pairs. @@ -142,6 +232,7 @@ filter=+filter1,-filter2,... exclude_files=regex linelength=80 + root=subdir "set noparent" option prevents cpplint from traversing directory tree upwards looking for more .cfg files in parent directories. This option @@ -153,22 +244,28 @@ "exclude_files" allows to specify a regular expression to be matched against a file name. If the expression matches, the file is skipped and not run - through liner. + through the linter. + + "linelength" specifies the allowed line length for the project. - "linelength" allows to specify the allowed line length for the project. + The "root" option is similar in function to the --root flag (see example + above). CPPLINT.cfg has an effect on files in the same directory and all - sub-directories, unless overridden by a nested configuration file. + subdirectories, unless overridden by a nested configuration file. Example file: filter=-build/include_order,+build/include_alpha - exclude_files=.*\.cc + exclude_files=.*\\.cc The above example disables build/include_order warning and enables build/include_alpha as well as excludes all .cc from being processed by linter, in the current directory (where the .cfg - file is located) and all sub-directories. -""" + file is located) and all subdirectories. +""" % (list(GetAllExtensions()), + ','.join(list(GetAllExtensions())), + GetHeaderExtensions(), + ','.join(GetHeaderExtensions())) # We categorize each error message we print. Here are the categories. # We want an explicit list so we can list them all in cpplint --filter=. @@ -177,15 +274,19 @@ _ERROR_CATEGORIES = [ 'build/class', 'build/c++11', + 'build/c++14', + 'build/c++tr1', 'build/deprecated', 'build/endif_comment', 'build/explicit_make_pair', 'build/forward_decl', 'build/header_guard', 'build/include', + 'build/include_subdir', 'build/include_alpha', 'build/include_order', 'build/include_what_you_use', + 'build/namespaces_literals', 'build/namespaces', 'build/printf_format', 'build/storage_class', @@ -196,7 +297,6 @@ 'readability/check', 'readability/constructors', 'readability/fn_size', - 'readability/function', 'readability/inheritance', 'readability/multiline_comment', 'readability/multiline_string', @@ -227,6 +327,7 @@ 'whitespace/comma', 'whitespace/comments', 'whitespace/empty_conditional_body', + 'whitespace/empty_if_body', 'whitespace/empty_loop_body', 'whitespace/end_of_line', 'whitespace/ending_newline', @@ -245,6 +346,7 @@ # compatibility they may still appear in NOLINT comments. _LEGACY_ERROR_CATEGORIES = [ 'readability/streams', + 'readability/function', ] # The default state of the category filter. This is overridden by the --filter= @@ -253,6 +355,16 @@ # All entries here should start with a '-' or '+', as in the --filter= flag. _DEFAULT_FILTERS = ['-build/include_alpha'] +# The default list of categories suppressed for C (not C++) files. +_DEFAULT_C_SUPPRESSED_CATEGORIES = [ + 'readability/casting', + ] + +# The default list of categories suppressed for Linux Kernel files. 
+_DEFAULT_KERNEL_SUPPRESSED_CATEGORIES = [ + 'whitespace/tab', + ] + # We used to check for high-bit characters, but after much discussion we # decided those were OK, as long as they were in UTF-8 and didn't represent # hard-coded international strings, which belong in a separate i18n file. @@ -346,6 +458,7 @@ 'random', 'ratio', 'regex', + 'scoped_allocator', 'set', 'sstream', 'stack', @@ -393,6 +506,19 @@ 'cwctype', ]) +# Type names +_TYPES = re.compile( + r'^(?:' + # [dcl.type.simple] + r'(char(16_t|32_t)?)|wchar_t|' + r'bool|short|int|long|signed|unsigned|float|double|' + # [support.types] + r'(ptrdiff_t|size_t|max_align_t|nullptr_t)|' + # [cstdint.syn] + r'(u?int(_fast|_least)?(8|16|32|64)_t)|' + r'(u?int(max|ptr)_t)|' + r')$') + # These headers are excluded from [build/include] and [build/include_order] # checks: @@ -402,20 +528,23 @@ _THIRD_PARTY_HEADERS_PATTERN = re.compile( r'^(?:[^/]*[A-Z][^/]*\.h|lua\.h|lauxlib\.h|lualib\.h)$') +# Pattern for matching FileInfo.BaseName() against test file name +_test_suffixes = ['_test', '_regtest', '_unittest'] +_TEST_FILE_SUFFIX = '(' + '|'.join(_test_suffixes) + r')$' + +# Pattern that matches only complete whitespace, possibly across multiple lines. +_EMPTY_CONDITIONAL_BODY_PATTERN = re.compile(r'^\s*$', re.DOTALL) # Assertion macros. These are defined in base/logging.h and -# testing/base/gunit.h. Note that the _M versions need to come first -# for substring matching to work. +# testing/base/public/gunit.h. _CHECK_MACROS = [ 'DCHECK', 'CHECK', - 'EXPECT_TRUE_M', 'EXPECT_TRUE', - 'ASSERT_TRUE_M', 'ASSERT_TRUE', - 'EXPECT_FALSE_M', 'EXPECT_FALSE', - 'ASSERT_FALSE_M', 'ASSERT_FALSE', + 'EXPECT_TRUE', 'ASSERT_TRUE', + 'EXPECT_FALSE', 'ASSERT_FALSE', ] # Replacement macros for CHECK/DCHECK/EXPECT_TRUE/EXPECT_FALSE -_CHECK_REPLACEMENT = dict([(m, {}) for m in _CHECK_MACROS]) +_CHECK_REPLACEMENT = dict([(macro_var, {}) for macro_var in _CHECK_MACROS]) for op, replacement in [('==', 'EQ'), ('!=', 'NE'), ('>=', 'GE'), ('>', 'GT'), @@ -424,16 +553,12 @@ _CHECK_REPLACEMENT['CHECK'][op] = 'CHECK_%s' % replacement _CHECK_REPLACEMENT['EXPECT_TRUE'][op] = 'EXPECT_%s' % replacement _CHECK_REPLACEMENT['ASSERT_TRUE'][op] = 'ASSERT_%s' % replacement - _CHECK_REPLACEMENT['EXPECT_TRUE_M'][op] = 'EXPECT_%s_M' % replacement - _CHECK_REPLACEMENT['ASSERT_TRUE_M'][op] = 'ASSERT_%s_M' % replacement for op, inv_replacement in [('==', 'NE'), ('!=', 'EQ'), ('>=', 'LT'), ('>', 'LE'), ('<=', 'GT'), ('<', 'GE')]: _CHECK_REPLACEMENT['EXPECT_FALSE'][op] = 'EXPECT_%s' % inv_replacement _CHECK_REPLACEMENT['ASSERT_FALSE'][op] = 'ASSERT_%s' % inv_replacement - _CHECK_REPLACEMENT['EXPECT_FALSE_M'][op] = 'EXPECT_%s_M' % inv_replacement - _CHECK_REPLACEMENT['ASSERT_FALSE_M'][op] = 'ASSERT_%s_M' % inv_replacement # Alternative tokens and their replacements. For full list, see section 2.5 # Alternative tokens [lex.digraph] in the C++ standard. @@ -482,6 +607,12 @@ r'(?:\s+(volatile|__volatile__))?' r'\s*[{(]') +# Match strings that indicate we're working on a C (not C++) file. +_SEARCH_C_FILE = re.compile(r'\b(?:LINT_C_FILE|' + r'vim?:\s*.*(\s*|:)filetype=c(\s*|:|$))') + +# Match string that indicates we're working on a Linux Kernel file. +_SEARCH_KERNEL_FILE = re.compile(r'\b(?:LINT_KERNEL_FILE)') _regexp_compile_cache = {} @@ -493,16 +624,64 @@ # This is set by --root flag. _root = None +# The top level repository directory. If set, _root is calculated relative to +# this directory instead of the directory containing version control artifacts. 
+# This is set by the --repository flag. +_repository = None + +# Files to exclude from linting. This is set by the --exclude flag. +_excludes = None + +# Whether to supress PrintInfo messages +_quiet = False + # The allowed line length of files. # This is set by --linelength flag. _line_length = 80 -# The allowed extensions for file names -# This is set by --extensions flag. -_valid_extensions = set(['cc', 'h', 'cpp', 'cu', 'cuh']) +try: + xrange(1, 0) +except NameError: + # -- pylint: disable=redefined-builtin + xrange = range + +try: + unicode +except NameError: + # -- pylint: disable=redefined-builtin + basestring = unicode = str + +try: + long(2) +except NameError: + # -- pylint: disable=redefined-builtin + long = int + +if sys.version_info < (3,): + # -- pylint: disable=no-member + # BINARY_TYPE = str + itervalues = dict.itervalues + iteritems = dict.iteritems +else: + # BINARY_TYPE = bytes + itervalues = dict.values + iteritems = dict.items + +def unicode_escape_decode(x): + if sys.version_info < (3,): + return codecs.unicode_escape_decode(x)[0] + else: + return x + +# {str, bool}: a map from error categories to booleans which indicate if the +# category should be suppressed for every line. +_global_error_suppressions = {} + + + def ParseNolintSuppressions(filename, raw_line, linenum, error): - """Updates the global list of error-suppressions. + """Updates the global list of line error-suppressions. Parses any NOLINT comments on the current line, updating the global error_suppressions store. Reports an error if the NOLINT comment @@ -533,24 +712,45 @@ def ParseNolintSuppressions(filename, raw_line, linenum, error): 'Unknown NOLINT error category: %s' % category) +def ProcessGlobalSuppresions(lines): + """Updates the list of global error suppressions. + + Parses any lint directives in the file that have global effect. + + Args: + lines: An array of strings, each representing a line of the file, with the + last element being empty if the file is terminated with a newline. + """ + for line in lines: + if _SEARCH_C_FILE.search(line): + for category in _DEFAULT_C_SUPPRESSED_CATEGORIES: + _global_error_suppressions[category] = True + if _SEARCH_KERNEL_FILE.search(line): + for category in _DEFAULT_KERNEL_SUPPRESSED_CATEGORIES: + _global_error_suppressions[category] = True + + def ResetNolintSuppressions(): """Resets the set of NOLINT suppressions to empty.""" _error_suppressions.clear() + _global_error_suppressions.clear() def IsErrorSuppressedByNolint(category, linenum): """Returns true if the specified error category is suppressed on this line. Consults the global error_suppressions map populated by - ParseNolintSuppressions/ResetNolintSuppressions. + ParseNolintSuppressions/ProcessGlobalSuppresions/ResetNolintSuppressions. Args: category: str, the category of the error. linenum: int, the current line number. Returns: - bool, True iff the error should be suppressed due to a NOLINT comment. + bool, True iff the error should be suppressed due to a NOLINT comment or + global suppression. 
""" - return (linenum in _error_suppressions.get(category, set()) or + return (_global_error_suppressions.get(category, False) or + linenum in _error_suppressions.get(category, set()) or linenum in _error_suppressions.get(None, set())) @@ -589,6 +789,11 @@ def Search(pattern, s): return _regexp_compile_cache[pattern].search(s) +def _IsSourceExtension(s): + """File extension (excluding dot) matches a source file extension.""" + return s in GetNonHeaderExtensions() + + class _IncludeState(object): """Tracks line numbers for includes, and the order in which includes appear. @@ -626,6 +831,8 @@ class _IncludeState(object): def __init__(self): self.include_list = [[]] + self._section = None + self._last_header = None self.ResetSection('') def FindHeader(self, header): @@ -769,9 +976,16 @@ def __init__(self): # output format: # "emacs" - format that emacs can parse (default) + # "eclipse" - format that eclipse can parse # "vs7" - format that Microsoft Visual Studio 7 can parse + # "junit" - format that Jenkins, Bamboo, etc can parse self.output_format = 'emacs' + # For JUnit output, save errors and failures until the end so that they + # can be written into the XML + self._junit_errors = [] + self._junit_failures = [] + def SetOutputFormat(self, output_format): """Sets the output format for errors.""" self.output_format = output_format @@ -840,10 +1054,69 @@ def IncrementErrorCount(self, category): def PrintErrorCounts(self): """Print a summary of errors by category, and the total.""" - for category, count in self.errors_by_category.iteritems(): - sys.stderr.write('Category \'%s\' errors found: %d\n' % + for category, count in sorted(iteritems(self.errors_by_category)): + self.PrintInfo('Category \'%s\' errors found: %d\n' % (category, count)) - sys.stderr.write('Total errors found: %d\n' % self.error_count) + if self.error_count > 0: + self.PrintInfo('Total errors found: %d\n' % self.error_count) + + def PrintInfo(self, message): + if not _quiet and self.output_format != 'junit': + sys.stderr.write(message) + + def PrintError(self, message): + if self.output_format == 'junit': + self._junit_errors.append(message) + else: + sys.stderr.write(message) + + def AddJUnitFailure(self, filename, linenum, message, category, confidence): + self._junit_failures.append((filename, linenum, message, category, + confidence)) + + def FormatJUnitXML(self): + num_errors = len(self._junit_errors) + num_failures = len(self._junit_failures) + + testsuite = xml.etree.ElementTree.Element('testsuite') + testsuite.attrib['name'] = 'cpplint' + testsuite.attrib['errors'] = str(num_errors) + testsuite.attrib['failures'] = str(num_failures) + + if num_errors == 0 and num_failures == 0: + testsuite.attrib['tests'] = str(1) + xml.etree.ElementTree.SubElement(testsuite, 'testcase', name='passed') + + else: + testsuite.attrib['tests'] = str(num_errors + num_failures) + if num_errors > 0: + testcase = xml.etree.ElementTree.SubElement(testsuite, 'testcase') + testcase.attrib['name'] = 'errors' + error = xml.etree.ElementTree.SubElement(testcase, 'error') + error.text = '\n'.join(self._junit_errors) + if num_failures > 0: + # Group failures by file + failed_file_order = [] + failures_by_file = {} + for failure in self._junit_failures: + failed_file = failure[0] + if failed_file not in failed_file_order: + failed_file_order.append(failed_file) + failures_by_file[failed_file] = [] + failures_by_file[failed_file].append(failure) + # Create a testcase for each file + for failed_file in failed_file_order: + failures = 
failures_by_file[failed_file] + testcase = xml.etree.ElementTree.SubElement(testsuite, 'testcase') + testcase.attrib['name'] = failed_file + failure = xml.etree.ElementTree.SubElement(testcase, 'failure') + template = '{0}: {1} [{2}] [{3}]' + texts = [template.format(f[1], f[2], f[3], f[4]) for f in failures] + failure.text = '\n'.join(texts) + + xml_decl = '\n' + return xml_decl + xml.etree.ElementTree.tostring(testsuite, 'utf-8').decode('utf-8') + _cpplint_state = _CppLintState() @@ -944,6 +1217,9 @@ def Check(self, error, filename, linenum): filename: The name of the current file. linenum: The number of the line to check. """ + if not self.in_a_function: + return + if Match(r'T(EST|est)', self.current_function): base_trigger = self._TEST_TRIGGER else: @@ -986,7 +1262,7 @@ def FullName(self): return os.path.abspath(self._filename).replace('\\', '/') def RepositoryName(self): - """FullName after removing the local path to the repository. + r"""FullName after removing the local path to the repository. If we have a real absolute path name here we can try to do something smart: detecting the root of the checkout and truncating /path/to/checkout from @@ -1000,6 +1276,20 @@ def RepositoryName(self): if os.path.exists(fullname): project_dir = os.path.dirname(fullname) + # If the user specified a repository path, it exists, and the file is + # contained in it, use the specified repository path + if _repository: + repo = FileInfo(_repository).FullName() + root_dir = project_dir + while os.path.exists(root_dir): + # allow case insensitive compare on Windows + if os.path.normcase(root_dir) == os.path.normcase(repo): + return os.path.relpath(fullname, root_dir).replace('\\', '/') + one_up_dir = os.path.dirname(root_dir) + if one_up_dir == root_dir: + break + root_dir = one_up_dir + if os.path.exists(os.path.join(project_dir, ".svn")): # If there's a .svn file in the current directory, we recursively look # up the directory tree for the top of the SVN checkout @@ -1014,12 +1304,13 @@ def RepositoryName(self): # Not SVN <= 1.6? Try to find a git, hg, or svn top level directory by # searching up from the current path. 
- root_dir = os.path.dirname(fullname) - while (root_dir != os.path.dirname(root_dir) and - not os.path.exists(os.path.join(root_dir, ".git")) and - not os.path.exists(os.path.join(root_dir, ".hg")) and - not os.path.exists(os.path.join(root_dir, ".svn"))): - root_dir = os.path.dirname(root_dir) + root_dir = current_dir = os.path.dirname(fullname) + while current_dir != os.path.dirname(current_dir): + if (os.path.exists(os.path.join(current_dir, ".git")) or + os.path.exists(os.path.join(current_dir, ".hg")) or + os.path.exists(os.path.join(current_dir, ".svn"))): + root_dir = current_dir + current_dir = os.path.dirname(current_dir) if (os.path.exists(os.path.join(root_dir, ".git")) or os.path.exists(os.path.join(root_dir, ".hg")) or @@ -1049,7 +1340,7 @@ def BaseName(self): return self.Split()[1] def Extension(self): - """File extension - text following the final period.""" + """File extension - text following the final period, includes that period.""" return self.Split()[2] def NoExtension(self): @@ -1058,7 +1349,7 @@ def NoExtension(self): def IsSource(self): """File has a source file extension.""" - return self.Extension()[1:] in ('c', 'cc', 'cpp', 'cxx') + return _IsSourceExtension(self.Extension()[1:]) def _ShouldPrintError(category, confidence, linenum): @@ -1114,15 +1405,18 @@ def Error(filename, linenum, category, confidence, message): if _ShouldPrintError(category, confidence, linenum): _cpplint_state.IncrementErrorCount(category) if _cpplint_state.output_format == 'vs7': - sys.stderr.write('%s(%s): %s [%s] [%d]\n' % ( + _cpplint_state.PrintError('%s(%s): warning: %s [%s] [%d]\n' % ( filename, linenum, message, category, confidence)) elif _cpplint_state.output_format == 'eclipse': sys.stderr.write('%s:%s: warning: %s [%s] [%d]\n' % ( filename, linenum, message, category, confidence)) + elif _cpplint_state.output_format == 'junit': + _cpplint_state.AddJUnitFailure(filename, linenum, message, category, + confidence) else: - sys.stderr.write('%s:%s: %s [%s] [%d]\n' % ( - filename, linenum, message, category, confidence)) - + final_message = '%s:%s: %s [%s] [%d]\n' % ( + filename, linenum, message, category, confidence) + sys.stderr.write(final_message) # Matches standard C++ escape sequences per 2.13.2.3 of the C++ standard. _RE_PATTERN_CLEANSE_LINE_ESCAPES = re.compile( @@ -1204,8 +1498,18 @@ def CleanseRawStrings(raw_lines): while delimiter is None: # Look for beginning of a raw string. # See 2.14.15 [lex.string] for syntax. - matched = Match(r'^(.*)\b(?:R|u8R|uR|UR|LR)"([^\s\\()]*)\((.*)$', line) - if matched: + # + # Once we have matched a raw string, we check the prefix of the + # line to make sure that the line is not part of a single line + # comment. It's done this way because we remove raw strings + # before removing comments as opposed to removing comments + # before removing raw strings. This is because there are some + # cpplint checks that requires the comments to be preserved, but + # we don't want to check comments that are inside raw strings. + matched = Match(r'^(.*?)\b(?:R|u8R|uR|UR|LR)"([^\s\\()]*)\((.*)$', line) + if (matched and + not Match(r'^([^\'"]|\'(\\.|[^\'])*\'|"(\\.|[^"])*")*//', + matched.group(1))): delimiter = ')' + matched.group(2) + '"' end = matched.group(3).find(delimiter) @@ -1624,7 +1928,7 @@ def CheckForCopyright(filename, lines, error): # We'll say it should occur by line 10. Don't forget there's a # dummy line at the front. 
- for line in xrange(1, min(len(lines), 11)): + for line in range(1, min(len(lines), 11)): if re.search(r'Copyright', lines[line], re.I): break else: # means no copyright line was found error(filename, 0, 'legal/copyright', 5, @@ -1666,11 +1970,16 @@ def GetHeaderGuardCPPVariable(filename): filename = re.sub(r'/\.flymake/([^/]*)$', r'/\1', filename) # Replace 'c++' with 'cpp'. filename = filename.replace('C++', 'cpp').replace('c++', 'cpp') - + fileinfo = FileInfo(filename) file_path_from_root = fileinfo.RepositoryName() if _root: - file_path_from_root = re.sub('^' + _root + os.sep, '', file_path_from_root) + suffix = os.sep + # On Windows using directory separator will leave us with + # "bogus escape error" unless we properly escape regex. + if suffix == '\\': + suffix += '\\' + file_path_from_root = re.sub('^' + _root + suffix, '', file_path_from_root) return re.sub(r'[^a-zA-Z0-9]', '_', file_path_from_root).upper() + '_' @@ -1697,6 +2006,11 @@ def CheckForHeaderGuard(filename, clean_lines, error): if Search(r'//\s*NOLINT\(build/header_guard\)', i): return + # Allow pragma once instead of header guards + for i in raw_lines: + if Search(r'^\s*#pragma\s+once', i): + return + cppvar = GetHeaderGuardCPPVariable(filename) ifndef = '' @@ -1773,28 +2087,30 @@ def CheckForHeaderGuard(filename, clean_lines, error): def CheckHeaderFileIncluded(filename, include_state, error): - """Logs an error if a .cc file does not include its header.""" + """Logs an error if a source file does not include its header.""" # Do not check test files - if filename.endswith('_test.cc') or filename.endswith('_unittest.cc'): - return - fileinfo = FileInfo(filename) - headerfile = filename[0:len(filename) - 2] + 'h' - if not os.path.exists(headerfile): + if Search(_TEST_FILE_SUFFIX, fileinfo.BaseName()): return - headername = FileInfo(headerfile).RepositoryName() - first_include = 0 - for section_list in include_state.include_list: - for f in section_list: - if headername in f[0] or f[0] in headername: - return - if not first_include: - first_include = f[1] - error(filename, first_include, 'build/include', 5, - '%s should include its header file %s' % (fileinfo.RepositoryName(), - headername)) + for ext in GetHeaderExtensions(): + basefilename = filename[0:len(filename) - len(fileinfo.Extension())] + headerfile = basefilename + '.' + ext + if not os.path.exists(headerfile): + continue + headername = FileInfo(headerfile).RepositoryName() + first_include = None + for section_list in include_state.include_list: + for f in section_list: + if headername in f[0] or f[0] in headername: + return + if not first_include: + first_include = f[1] + + error(filename, first_include, 'build/include', 5, + '%s should include its header file %s' % (fileinfo.RepositoryName(), + headername)) def CheckForBadCharacters(filename, lines, error): @@ -1815,7 +2131,7 @@ def CheckForBadCharacters(filename, lines, error): error: The function to call with any errors found. 
""" for linenum, line in enumerate(lines): - if u'\ufffd' in line: + if unicode_escape_decode('\ufffd') in line: error(filename, linenum, 'readability/utf8', 5, 'Line contains invalid UTF-8 (or Unicode replacement character).') if '\0' in line: @@ -1997,7 +2313,8 @@ def IsForwardClassDeclaration(clean_lines, linenum): class _BlockInfo(object): """Stores information about a generic block of code.""" - def __init__(self, seen_open_brace): + def __init__(self, linenum, seen_open_brace): + self.starting_linenum = linenum self.seen_open_brace = seen_open_brace self.open_parentheses = 0 self.inline_asm = _NO_ASM @@ -2046,17 +2363,16 @@ def IsBlockInfo(self): class _ExternCInfo(_BlockInfo): """Stores information about an 'extern "C"' block.""" - def __init__(self): - _BlockInfo.__init__(self, True) + def __init__(self, linenum): + _BlockInfo.__init__(self, linenum, True) class _ClassInfo(_BlockInfo): """Stores information about a class.""" def __init__(self, name, class_or_struct, clean_lines, linenum): - _BlockInfo.__init__(self, False) + _BlockInfo.__init__(self, linenum, False) self.name = name - self.starting_linenum = linenum self.is_derived = False self.check_namespace_indentation = True if class_or_struct == 'struct': @@ -2124,9 +2440,8 @@ class _NamespaceInfo(_BlockInfo): """Stores information about a namespace.""" def __init__(self, name, linenum): - _BlockInfo.__init__(self, False) + _BlockInfo.__init__(self, linenum, False) self.name = name or '' - self.starting_linenum = linenum self.check_namespace_indentation = True def CheckEnd(self, filename, clean_lines, linenum, error): @@ -2145,7 +2460,7 @@ def CheckEnd(self, filename, clean_lines, linenum, error): # deciding what these nontrivial things are, so this check is # triggered by namespace size only, which works most of the time. if (linenum - self.starting_linenum < 10 - and not Match(r'};*\s*(//|/\*).*\bnamespace\b', line)): + and not Match(r'^\s*};*\s*(//|/\*).*\bnamespace\b', line)): return # Look for matching comment at end of namespace. @@ -2162,18 +2477,18 @@ def CheckEnd(self, filename, clean_lines, linenum, error): # expected namespace. if self.name: # Named namespace - if not Match((r'};*\s*(//|/\*).*\bnamespace\s+' + re.escape(self.name) + - r'[\*/\.\\\s]*$'), + if not Match((r'^\s*};*\s*(//|/\*).*\bnamespace\s+' + + re.escape(self.name) + r'[\*/\.\\\s]*$'), line): error(filename, linenum, 'readability/namespace', 5, 'Namespace should be terminated with "// namespace %s"' % self.name) else: # Anonymous namespace - if not Match(r'};*\s*(//|/\*).*\bnamespace[\*/\.\\\s]*$', line): + if not Match(r'^\s*};*\s*(//|/\*).*\bnamespace[\*/\.\\\s]*$', line): # If "// namespace anonymous" or "// anonymous namespace (more text)", # mention "// anonymous namespace" as an acceptable form - if Match(r'}.*\b(namespace anonymous|anonymous namespace)\b', line): + if Match(r'^\s*}.*\b(namespace anonymous|anonymous namespace)\b', line): error(filename, linenum, 'readability/namespace', 5, 'Anonymous namespace should be terminated with "// namespace"' ' or "// anonymous namespace"') @@ -2445,7 +2760,7 @@ def Update(self, filename, clean_lines, linenum, error): # class LOCKABLE API Object { # }; class_decl_match = Match( - r'^(\s*(?:template\s*<[\w\s<>,:]*>\s*)?' + r'^(\s*(?:template\s*<[\w\s<>,:=]*>\s*)?' 
r'(class|struct)\s+(?:[A-Z_]+\s+)*(\w+(?:::\w+)*))' r'(.*)$', line) if (class_decl_match and @@ -2512,9 +2827,9 @@ def Update(self, filename, clean_lines, linenum, error): if not self.SeenOpenBrace(): self.stack[-1].seen_open_brace = True elif Match(r'^extern\s*"[^"]*"\s*\{', line): - self.stack.append(_ExternCInfo()) + self.stack.append(_ExternCInfo(linenum)) else: - self.stack.append(_BlockInfo(True)) + self.stack.append(_BlockInfo(linenum, True)) if _MATCH_ASM.match(line): self.stack[-1].inline_asm = _BLOCK_ASM @@ -2626,7 +2941,8 @@ def CheckForNonStandardConstructs(filename, clean_lines, linenum, r'\s+(register|static|extern|typedef)\b', line): error(filename, linenum, 'build/storage_class', 5, - 'Storage class (static, extern, typedef, etc) should be first.') + 'Storage-class specifier (static, extern, typedef, etc) should be ' + 'at the beginning of the declaration.') if Match(r'\s*#\s*endif\s*[^/\s]+', line): error(filename, linenum, 'build/endif_comment', 5, @@ -2665,9 +2981,7 @@ def CheckForNonStandardConstructs(filename, clean_lines, linenum, base_classname = classinfo.name.split('::')[-1] # Look for single-argument constructors that aren't marked explicit. - # Technically a valid construct, but against style. Also look for - # non-single-argument constructors which are also technically valid, but - # strongly suggest something is wrong. + # Technically a valid construct, but against style. explicit_constructor_match = Match( r'\s+(?:inline\s+)?(explicit\s+)?(?:inline\s+)?%s\s*' r'\(((?:[^()]|\([^()]*\))*)\)' @@ -2694,6 +3008,7 @@ def CheckForNonStandardConstructs(filename, clean_lines, linenum, constructor_args[i] = constructor_arg i += 1 + variadic_args = [arg for arg in constructor_args if '&&...' in arg] defaulted_args = [arg for arg in constructor_args if '=' in arg] noarg_constructor = (not constructor_args or # empty arg list # 'void' arg specifier @@ -2704,7 +3019,10 @@ def CheckForNonStandardConstructs(filename, clean_lines, linenum, # all but at most one arg defaulted (len(constructor_args) >= 1 and not noarg_constructor and - len(defaulted_args) >= len(constructor_args) - 1)) + len(defaulted_args) >= len(constructor_args) - 1) or + # variadic arguments with zero or one argument + (len(constructor_args) <= 2 and + len(variadic_args) >= 1)) initializer_list_constructor = bool( onearg_constructor and Search(r'\bstd\s*::\s*initializer_list\b', constructor_args[0])) @@ -2717,7 +3035,7 @@ def CheckForNonStandardConstructs(filename, clean_lines, linenum, onearg_constructor and not initializer_list_constructor and not copy_constructor): - if defaulted_args: + if defaulted_args or variadic_args: error(filename, linenum, 'runtime/explicit', 5, 'Constructors callable with one argument ' 'should be marked explicit.') @@ -2728,10 +3046,6 @@ def CheckForNonStandardConstructs(filename, clean_lines, linenum, if noarg_constructor: error(filename, linenum, 'runtime/explicit', 5, 'Zero-parameter constructors should not be marked explicit.') - else: - error(filename, linenum, 'runtime/explicit', 0, - 'Constructors that require multiple arguments ' - 'should not be marked explicit.') def CheckSpacingForFunctionCall(filename, clean_lines, linenum, error): @@ -2786,6 +3100,7 @@ def CheckSpacingForFunctionCall(filename, clean_lines, linenum, error): error(filename, linenum, 'whitespace/parens', 2, 'Extra space after (') if (Search(r'\w\s+\(', fncall) and + not Search(r'_{0,2}asm_{0,2}\s+_{0,2}volatile_{0,2}\s+\(', fncall) and not Search(r'#\s*define|typedef|using\s+\w+\s*=', fncall) and not 
Search(r'\w\s+\((\w+::)*\*\w+\)\(', fncall) and not Search(r'\bcase\s+\(', fncall)): @@ -2844,7 +3159,7 @@ def CheckForFunctionLengths(filename, clean_lines, linenum, """Reports for long function bodies. For an overview why this is done, see: - http://google-styleguide.googlecode.com/svn/trunk/cppguide.xml#Write_Short_Functions + https://google-styleguide.googlecode.com/svn/trunk/cppguide.xml#Write_Short_Functions Uses a simplistic algorithm assuming other style guidelines (especially spacing) are followed. @@ -2879,7 +3194,7 @@ def CheckForFunctionLengths(filename, clean_lines, linenum, if starting_func: body_found = False - for start_linenum in xrange(linenum, clean_lines.NumLines()): + for start_linenum in range(linenum, clean_lines.NumLines()): start_line = lines[start_linenum] joined_line += ' ' + start_line.lstrip() if Search(r'(;|})', start_line): # Declarations and trivial functions @@ -2923,9 +3238,7 @@ def CheckComment(line, filename, linenum, next_line_start, error): commentpos = line.find('//') if commentpos != -1: # Check if the // may be in quotes. If so, ignore it - # Comparisons made explicit for clarity -- pylint: disable=g-explicit-bool-comparison - if (line.count('"', 0, commentpos) - - line.count('\\"', 0, commentpos)) % 2 == 0: # not in quotes + if re.sub(r'\\.', '', line[0:commentpos]).count('"') % 2 == 0: # Allow one space for new scopes, two spaces otherwise: if (not (Match(r'^.*{ *//', line) and next_line_start == commentpos) and ((commentpos >= 1 and @@ -3174,8 +3487,8 @@ def CheckOperatorSpacing(filename, clean_lines, linenum, error): # macro context and don't do any checks. This avoids false # positives. # - # Note that && is not included here. Those are checked separately - # in CheckRValueReference + # Note that && is not included here. This is because there are too + # many false positives due to RValue references. match = Search(r'[^<>=!\s](==|!=|<=|>=|\|\|)[^<>=!\s,;\)]', line) if match: error(filename, linenum, 'whitespace/operators', 3, @@ -3209,7 +3522,7 @@ def CheckOperatorSpacing(filename, clean_lines, linenum, error): # # We also allow operators following an opening parenthesis, since # those tend to be macros that deal with operators. - match = Search(r'(operator|[^\s(<])(?:L|UL|ULL|l|ul|ull)?<<([^\s,=<])', line) + match = Search(r'(operator|[^\s(<])(?:L|UL|LL|ULL|l|ul|ll|ull)?<<([^\s,=<])', line) if (match and not (match.group(1).isdigit() and match.group(2).isdigit()) and not (match.group(1) == 'operator' and match.group(2) == ';')): error(filename, linenum, 'whitespace/operators', 3, @@ -3313,22 +3626,90 @@ def CheckCommaSpacing(filename, clean_lines, linenum, error): 'Missing space after ;') -def CheckBracesSpacing(filename, clean_lines, linenum, error): +def _IsType(clean_lines, nesting_state, expr): + """Check if expression looks like a type name, returns true if so. + + Args: + clean_lines: A CleansedLines instance containing the file. + nesting_state: A NestingState instance which maintains information about + the current stack of nested blocks being parsed. + expr: The expression to check. + Returns: + True, if token looks like a type. + """ + # Keep only the last token in the expression + last_word = Match(r'^.*(\b\S+)$', expr) + if last_word: + token = last_word.group(1) + else: + token = expr + + # Match native types and stdint types + if _TYPES.match(token): + return True + + # Try a bit harder to match templated types. 
Walk up the nesting + # stack until we find something that resembles a typename + # declaration for what we are looking for. + typename_pattern = (r'\b(?:typename|class|struct)\s+' + re.escape(token) + + r'\b') + block_index = len(nesting_state.stack) - 1 + while block_index >= 0: + if isinstance(nesting_state.stack[block_index], _NamespaceInfo): + return False + + # Found where the opening brace is. We want to scan from this + # line up to the beginning of the function, minus a few lines. + # template + # class C + # : public ... { // start scanning here + last_line = nesting_state.stack[block_index].starting_linenum + + next_block_start = 0 + if block_index > 0: + next_block_start = nesting_state.stack[block_index - 1].starting_linenum + first_line = last_line + while first_line >= next_block_start: + if clean_lines.elided[first_line].find('template') >= 0: + break + first_line -= 1 + if first_line < next_block_start: + # Didn't find any "template" keyword before reaching the next block, + # there are probably no template things to check for this block + block_index -= 1 + continue + + # Look for typename in the specified range + for i in xrange(first_line, last_line + 1, 1): + if Search(typename_pattern, clean_lines.elided[i]): + return True + block_index -= 1 + + return False + + +def CheckBracesSpacing(filename, clean_lines, linenum, nesting_state, error): """Checks for horizontal spacing near commas. Args: filename: The name of the current file. clean_lines: A CleansedLines instance containing the file. linenum: The number of the line to check. + nesting_state: A NestingState instance which maintains information about + the current stack of nested blocks being parsed. error: The function to call with any errors found. """ line = clean_lines.elided[linenum] # Except after an opening paren, or after another opening brace (in case of # an initializer list, for instance), you should have spaces before your - # braces. And since you should never have braces at the beginning of a line, - # this is an easy test. + # braces when they are delimiting blocks, classes, namespaces etc. + # And since you should never have braces at the beginning of a line, + # this is an easy test. Except that braces used for initialization don't + # follow the same rule; we often don't want spaces before those. match = Match(r'^(.*[^ ({>]){', line) + if match: # Try a bit harder to check for brace initialization. This # happens in one of the following forms: @@ -3358,6 +3739,7 @@ def CheckBracesSpacing(filename, clean_lines, linenum, error): # There is a false negative with this approach if people inserted # spurious semicolons, e.g. "if (cond){};", but we will catch the # spurious semicolon with a separate check. + leading_text = match.group(1) (endline, endlinenum, endpos) = CloseExpression( clean_lines, linenum, len(match.group(1))) trailing_text = '' @@ -3366,7 +3748,11 @@ def CheckBracesSpacing(filename, clean_lines, linenum, error): for offset in xrange(endlinenum + 1, min(endlinenum + 3, clean_lines.NumLines() - 1)): trailing_text += clean_lines.elided[offset] - if not Match(r'^[\s}]*[{.;,)<>\]:]', trailing_text): + # We also suppress warnings for `uint64_t{expression}` etc., as the style + # guide recommends brace initialization for integral types to avoid + # overflow/truncation. 
+ if (not Match(r'^[\s}]*[{.;,)<>\]:]', trailing_text) + and not _IsType(clean_lines, nesting_state, leading_text)): error(filename, linenum, 'whitespace/braces', 5, 'Missing space before {') @@ -3409,406 +3795,6 @@ def IsDecltype(clean_lines, linenum, column): return True return False - -def IsTemplateParameterList(clean_lines, linenum, column): - """Check if the token ending on (linenum, column) is the end of template<>. - - Args: - clean_lines: A CleansedLines instance containing the file. - linenum: the number of the line to check. - column: end column of the token to check. - Returns: - True if this token is end of a template parameter list, False otherwise. - """ - (_, startline, startpos) = ReverseCloseExpression( - clean_lines, linenum, column) - if (startpos > -1 and - Search(r'\btemplate\s*$', clean_lines.elided[startline][0:startpos])): - return True - return False - - -def IsRValueType(typenames, clean_lines, nesting_state, linenum, column): - """Check if the token ending on (linenum, column) is a type. - - Assumes that text to the right of the column is "&&" or a function - name. - - Args: - typenames: set of type names from template-argument-list. - clean_lines: A CleansedLines instance containing the file. - nesting_state: A NestingState instance which maintains information about - the current stack of nested blocks being parsed. - linenum: the number of the line to check. - column: end column of the token to check. - Returns: - True if this token is a type, False if we are not sure. - """ - prefix = clean_lines.elided[linenum][0:column] - - # Get one word to the left. If we failed to do so, this is most - # likely not a type, since it's unlikely that the type name and "&&" - # would be split across multiple lines. - match = Match(r'^(.*)(\b\w+|[>*)&])\s*$', prefix) - if not match: - return False - - # Check text following the token. If it's "&&>" or "&&," or "&&...", it's - # most likely a rvalue reference used inside a template. - suffix = clean_lines.elided[linenum][column:] - if Match(r'&&\s*(?:[>,]|\.\.\.)', suffix): - return True - - # Check for known types and end of templates: - # int&& variable - # vector&& variable - # - # Because this function is called recursively, we also need to - # recognize pointer and reference types: - # int* Function() - # int& Function() - if (match.group(2) in typenames or - match.group(2) in ['char', 'char16_t', 'char32_t', 'wchar_t', 'bool', - 'short', 'int', 'long', 'signed', 'unsigned', - 'float', 'double', 'void', 'auto', '>', '*', '&']): - return True - - # If we see a close parenthesis, look for decltype on the other side. - # decltype would unambiguously identify a type, anything else is - # probably a parenthesized expression and not a type. - if match.group(2) == ')': - return IsDecltype( - clean_lines, linenum, len(match.group(1)) + len(match.group(2)) - 1) - - # Check for casts and cv-qualifiers. - # match.group(1) remainder - # -------------- --------- - # const_cast< type&& - # const type&& - # type const&& - if Search(r'\b(?:const_cast\s*<|static_cast\s*<|dynamic_cast\s*<|' - r'reinterpret_cast\s*<|\w+\s)\s*$', - match.group(1)): - return True - - # Look for a preceding symbol that might help differentiate the context. 
- # These are the cases that would be ambiguous: - # match.group(1) remainder - # -------------- --------- - # Call ( expression && - # Declaration ( type&& - # sizeof ( type&& - # if ( expression && - # while ( expression && - # for ( type&& - # for( ; expression && - # statement ; type&& - # block { type&& - # constructor { expression && - start = linenum - line = match.group(1) - match_symbol = None - while start >= 0: - # We want to skip over identifiers and commas to get to a symbol. - # Commas are skipped so that we can find the opening parenthesis - # for function parameter lists. - match_symbol = Match(r'^(.*)([^\w\s,])[\w\s,]*$', line) - if match_symbol: - break - start -= 1 - line = clean_lines.elided[start] - - if not match_symbol: - # Probably the first statement in the file is an rvalue reference - return True - - if match_symbol.group(2) == '}': - # Found closing brace, probably an indicate of this: - # block{} type&& - return True - - if match_symbol.group(2) == ';': - # Found semicolon, probably one of these: - # for(; expression && - # statement; type&& - - # Look for the previous 'for(' in the previous lines. - before_text = match_symbol.group(1) - for i in xrange(start - 1, max(start - 6, 0), -1): - before_text = clean_lines.elided[i] + before_text - if Search(r'for\s*\([^{};]*$', before_text): - # This is the condition inside a for-loop - return False - - # Did not find a for-init-statement before this semicolon, so this - # is probably a new statement and not a condition. - return True - - if match_symbol.group(2) == '{': - # Found opening brace, probably one of these: - # block{ type&& = ... ; } - # constructor{ expression && expression } - - # Look for a closing brace or a semicolon. If we see a semicolon - # first, this is probably a rvalue reference. - line = clean_lines.elided[start][0:len(match_symbol.group(1)) + 1] - end = start - depth = 1 - while True: - for ch in line: - if ch == ';': - return True - elif ch == '{': - depth += 1 - elif ch == '}': - depth -= 1 - if depth == 0: - return False - end += 1 - if end >= clean_lines.NumLines(): - break - line = clean_lines.elided[end] - # Incomplete program? - return False - - if match_symbol.group(2) == '(': - # Opening parenthesis. Need to check what's to the left of the - # parenthesis. Look back one extra line for additional context. - before_text = match_symbol.group(1) - if linenum > 1: - before_text = clean_lines.elided[linenum - 1] + before_text - before_text = match_symbol.group(1) - - # Patterns that are likely to be types: - # [](type&& - # for (type&& - # sizeof(type&& - # operator=(type&& - # - if Search(r'(?:\]|\bfor|\bsizeof|\boperator\s*\S+\s*)\s*$', before_text): - return True - - # Patterns that are likely to be expressions: - # if (expression && - # while (expression && - # : initializer(expression && - # , initializer(expression && - # ( FunctionCall(expression && - # + FunctionCall(expression && - # + (expression && - # - # The last '+' represents operators such as '+' and '-'. - if Search(r'(?:\bif|\bwhile|[-+=%^(]*>)?\s*$', - match_symbol.group(1)) - if match_func: - # Check for constructors, which don't have return types. 
- if Search(r'\b(?:explicit|inline)$', match_func.group(1)): - return True - implicit_constructor = Match(r'\s*(\w+)\((?:const\s+)?(\w+)', prefix) - if (implicit_constructor and - implicit_constructor.group(1) == implicit_constructor.group(2)): - return True - return IsRValueType(typenames, clean_lines, nesting_state, linenum, - len(match_func.group(1))) - - # Nothing before the function name. If this is inside a block scope, - # this is probably a function call. - return not (nesting_state.previous_stack_top and - nesting_state.previous_stack_top.IsBlockInfo()) - - if match_symbol.group(2) == '>': - # Possibly a closing bracket, check that what's on the other side - # looks like the start of a template. - return IsTemplateParameterList( - clean_lines, start, len(match_symbol.group(1))) - - # Some other symbol, usually something like "a=b&&c". This is most - # likely not a type. - return False - - -def IsDeletedOrDefault(clean_lines, linenum): - """Check if current constructor or operator is deleted or default. - - Args: - clean_lines: A CleansedLines instance containing the file. - linenum: The number of the line to check. - Returns: - True if this is a deleted or default constructor. - """ - open_paren = clean_lines.elided[linenum].find('(') - if open_paren < 0: - return False - (close_line, _, close_paren) = CloseExpression( - clean_lines, linenum, open_paren) - if close_paren < 0: - return False - return Match(r'\s*=\s*(?:delete|default)\b', close_line[close_paren:]) - - -def IsRValueAllowed(clean_lines, linenum, typenames): - """Check if RValue reference is allowed on a particular line. - - Args: - clean_lines: A CleansedLines instance containing the file. - linenum: The number of the line to check. - typenames: set of type names from template-argument-list. - Returns: - True if line is within the region where RValue references are allowed. - """ - # Allow region marked by PUSH/POP macros - for i in xrange(linenum, 0, -1): - line = clean_lines.elided[i] - if Match(r'GOOGLE_ALLOW_RVALUE_REFERENCES_(?:PUSH|POP)', line): - if not line.endswith('PUSH'): - return False - for j in xrange(linenum, clean_lines.NumLines(), 1): - line = clean_lines.elided[j] - if Match(r'GOOGLE_ALLOW_RVALUE_REFERENCES_(?:PUSH|POP)', line): - return line.endswith('POP') - - # Allow operator= - line = clean_lines.elided[linenum] - if Search(r'\boperator\s*=\s*\(', line): - return IsDeletedOrDefault(clean_lines, linenum) - - # Allow constructors - match = Match(r'\s*(?:[\w<>]+::)*([\w<>]+)\s*::\s*([\w<>]+)\s*\(', line) - if match and match.group(1) == match.group(2): - return IsDeletedOrDefault(clean_lines, linenum) - if Search(r'\b(?:explicit|inline)\s+[\w<>]+\s*\(', line): - return IsDeletedOrDefault(clean_lines, linenum) - - if Match(r'\s*[\w<>]+\s*\(', line): - previous_line = 'ReturnType' - if linenum > 0: - previous_line = clean_lines.elided[linenum - 1] - if Match(r'^\s*$', previous_line) or Search(r'[{}:;]\s*$', previous_line): - return IsDeletedOrDefault(clean_lines, linenum) - - # Reject types not mentioned in template-argument-list - while line: - match = Match(r'^.*?(\w+)\s*&&(.*)$', line) - if not match: - break - if match.group(1) not in typenames: - return False - line = match.group(2) - - # All RValue types that were in template-argument-list should have - # been removed by now. Those were allowed, assuming that they will - # be forwarded. - # - # If there are no remaining RValue types left (i.e. types that were - # not found in template-argument-list), flag those as not allowed. 
- return line.find('&&') < 0 - - -def GetTemplateArgs(clean_lines, linenum): - """Find list of template arguments associated with this function declaration. - - Args: - clean_lines: A CleansedLines instance containing the file. - linenum: Line number containing the start of the function declaration, - usually one line after the end of the template-argument-list. - Returns: - Set of type names, or empty set if this does not appear to have - any template parameters. - """ - # Find start of function - func_line = linenum - while func_line > 0: - line = clean_lines.elided[func_line] - if Match(r'^\s*$', line): - return set() - if line.find('(') >= 0: - break - func_line -= 1 - if func_line == 0: - return set() - - # Collapse template-argument-list into a single string - argument_list = '' - match = Match(r'^(\s*template\s*)<', clean_lines.elided[func_line]) - if match: - # template-argument-list on the same line as function name - start_col = len(match.group(1)) - _, end_line, end_col = CloseExpression(clean_lines, func_line, start_col) - if end_col > -1 and end_line == func_line: - start_col += 1 # Skip the opening bracket - argument_list = clean_lines.elided[func_line][start_col:end_col] - - elif func_line > 1: - # template-argument-list one line before function name - match = Match(r'^(.*)>\s*$', clean_lines.elided[func_line - 1]) - if match: - end_col = len(match.group(1)) - _, start_line, start_col = ReverseCloseExpression( - clean_lines, func_line - 1, end_col) - if start_col > -1: - start_col += 1 # Skip the opening bracket - while start_line < func_line - 1: - argument_list += clean_lines.elided[start_line][start_col:] - start_col = 0 - start_line += 1 - argument_list += clean_lines.elided[func_line - 1][start_col:end_col] - - if not argument_list: - return set() - - # Extract type names - typenames = set() - while True: - match = Match(r'^[,\s]*(?:typename|class)(?:\.\.\.)?\s+(\w+)(.*)$', - argument_list) - if not match: - break - typenames.add(match.group(1)) - argument_list = match.group(2) - return typenames - - -def CheckRValueReference(filename, clean_lines, linenum, nesting_state, error): - """Check for rvalue references. - - Args: - filename: The name of the current file. - clean_lines: A CleansedLines instance containing the file. - linenum: The number of the line to check. - nesting_state: A NestingState instance which maintains information about - the current stack of nested blocks being parsed. - error: The function to call with any errors found. - """ - # Find lines missing spaces around &&. - # TODO(unknown): currently we don't check for rvalue references - # with spaces surrounding the && to avoid false positives with - # boolean expressions. - line = clean_lines.elided[linenum] - match = Match(r'^(.*\S)&&', line) - if not match: - match = Match(r'(.*)&&\S', line) - if (not match) or '(&&)' in line or Search(r'\boperator\s*$', match.group(1)): - return - - # Either poorly formed && or an rvalue reference, check the context - # to get a more accurate error message. Mostly we want to determine - # if what's to the left of "&&" is a type or not. 
- typenames = GetTemplateArgs(clean_lines, linenum) - and_pos = len(match.group(1)) - if IsRValueType(typenames, clean_lines, nesting_state, linenum, and_pos): - if not IsRValueAllowed(clean_lines, linenum, typenames): - error(filename, linenum, 'build/c++11', 3, - 'RValue references are an unapproved C++ feature.') - else: - error(filename, linenum, 'whitespace/operators', 3, - 'Missing spaces around &&') - - def CheckSectionSpacing(filename, clean_lines, class_info, linenum, error): """Checks for additional blank line issues related to sections. @@ -3906,10 +3892,13 @@ def CheckBraces(filename, clean_lines, linenum, error): # used for brace initializers inside function calls. We don't detect this # perfectly: we just don't complain if the last non-whitespace character on # the previous non-blank line is ',', ';', ':', '(', '{', or '}', or if the - # previous line starts a preprocessor block. + # previous line starts a preprocessor block. We also allow a brace on the + # following line if it is part of an array initialization and would not fit + # within the 80 character limit of the preceding line. prevline = GetPreviousNonBlankLine(clean_lines, linenum)[0] if (not Search(r'[,;:}{(]\s*$', prevline) and - not Match(r'\s*#', prevline)): + not Match(r'\s*#', prevline) and + not (GetLineWidth(prevline) > _line_length - 2 and '[]' in prevline)): error(filename, linenum, 'whitespace/braces', 4, '{ should almost always be at the end of the previous line') @@ -4085,13 +4074,14 @@ def CheckTrailingSemicolon(filename, clean_lines, linenum, error): # In addition to macros, we also don't want to warn on # - Compound literals # - Lambdas - # - alignas specifier with anonymous structs: + # - alignas specifier with anonymous structs + # - decltype closing_brace_pos = match.group(1).rfind(')') opening_parenthesis = ReverseCloseExpression( clean_lines, linenum, closing_brace_pos) if opening_parenthesis[2] > -1: line_prefix = opening_parenthesis[0][0:opening_parenthesis[2]] - macro = Search(r'\b([A-Z_]+)\s*$', line_prefix) + macro = Search(r'\b([A-Z_][A-Z0-9_]*)\s*$', line_prefix) func = Match(r'^(.*\])\s*$', line_prefix) if ((macro and macro.group(1) not in ( @@ -4100,6 +4090,7 @@ def CheckTrailingSemicolon(filename, clean_lines, linenum, error): 'LOCKS_EXCLUDED', 'INTERFACE_DEF')) or (func and not Search(r'\boperator\s*\[\s*\]', func.group(1))) or Search(r'\b(?:struct|union)\s+alignas\s*$', line_prefix) or + Search(r'\bdecltype$', line_prefix) or Search(r'\s+=\s*$', line_prefix)): match = None if (match and @@ -4136,6 +4127,14 @@ def CheckTrailingSemicolon(filename, clean_lines, linenum, error): # outputting warnings for the matching closing brace, if there are # nested blocks with trailing semicolons, we will get the error # messages in reversed order. + + # We need to check the line forward for NOLINT + raw_lines = clean_lines.raw_lines + ParseNolintSuppressions(filename, raw_lines[endlinenum-1], endlinenum-1, + error) + ParseNolintSuppressions(filename, raw_lines[endlinenum], endlinenum, + error) + error(filename, endlinenum, 'readability/braces', 4, "You don't need a ; after a }") @@ -4159,7 +4158,7 @@ def CheckEmptyBlockBody(filename, clean_lines, linenum, error): line = clean_lines.elided[linenum] matched = Match(r'\s*(for|while|if)\s*\(', line) if matched: - # Find the end of the conditional expression + # Find the end of the conditional expression. 
(end_line, end_linenum, end_pos) = CloseExpression( clean_lines, linenum, line.find('(')) @@ -4174,6 +4173,75 @@ def CheckEmptyBlockBody(filename, clean_lines, linenum, error): error(filename, end_linenum, 'whitespace/empty_loop_body', 5, 'Empty loop bodies should use {} or continue') + # Check for if statements that have completely empty bodies (no comments) + # and no else clauses. + if end_pos >= 0 and matched.group(1) == 'if': + # Find the position of the opening { for the if statement. + # Return without logging an error if it has no brackets. + opening_linenum = end_linenum + opening_line_fragment = end_line[end_pos:] + # Loop until EOF or find anything that's not whitespace or opening {. + while not Search(r'^\s*\{', opening_line_fragment): + if Search(r'^(?!\s*$)', opening_line_fragment): + # Conditional has no brackets. + return + opening_linenum += 1 + if opening_linenum == len(clean_lines.elided): + # Couldn't find conditional's opening { or any code before EOF. + return + opening_line_fragment = clean_lines.elided[opening_linenum] + # Set opening_line (opening_line_fragment may not be entire opening line). + opening_line = clean_lines.elided[opening_linenum] + + # Find the position of the closing }. + opening_pos = opening_line_fragment.find('{') + if opening_linenum == end_linenum: + # We need to make opening_pos relative to the start of the entire line. + opening_pos += end_pos + (closing_line, closing_linenum, closing_pos) = CloseExpression( + clean_lines, opening_linenum, opening_pos) + if closing_pos < 0: + return + + # Now construct the body of the conditional. This consists of the portion + # of the opening line after the {, all lines until the closing line, + # and the portion of the closing line before the }. + if (clean_lines.raw_lines[opening_linenum] != + CleanseComments(clean_lines.raw_lines[opening_linenum])): + # Opening line ends with a comment, so conditional isn't empty. + return + if closing_linenum > opening_linenum: + # Opening line after the {. Ignore comments here since we checked above. + bodylist = list(opening_line[opening_pos+1:]) + # All lines until closing line, excluding closing line, with comments. + bodylist.extend(clean_lines.raw_lines[opening_linenum+1:closing_linenum]) + # Closing line before the }. Won't (and can't) have comments. + bodylist.append(clean_lines.elided[closing_linenum][:closing_pos-1]) + body = '\n'.join(bodylist) + else: + # If statement has brackets and fits on a single line. + body = opening_line[opening_pos+1:closing_pos-1] + + # Check if the body is empty + if not _EMPTY_CONDITIONAL_BODY_PATTERN.search(body): + return + # The body is empty. Now make sure there's not an else clause. + current_linenum = closing_linenum + current_line_fragment = closing_line[closing_pos:] + # Loop until EOF or find anything that's not whitespace or else clause. + while Search(r'^\s*$|^(?=\s*else)', current_line_fragment): + if Search(r'^(?=\s*else)', current_line_fragment): + # Found an else clause, so don't log an error. + return + current_linenum += 1 + if current_linenum == len(clean_lines.elided): + break + current_line_fragment = clean_lines.elided[current_linenum] + + # The body is empty and there's no else clause until EOF or other code. + error(filename, end_linenum, 'whitespace/empty_if_body', 4, + ('If statement had no body and no else clause')) + def FindCheckMacro(line): """Find a replaceable CHECK-like macro. 
@@ -4393,6 +4461,7 @@ def CheckStyle(filename, clean_lines, linenum, file_extension, nesting_state, # raw strings, raw_lines = clean_lines.lines_without_raw_strings line = raw_lines[linenum] + prev = raw_lines[linenum - 1] if linenum > 0 else '' if line.find('\t') != -1: error(filename, linenum, 'whitespace/tab', 1, @@ -4416,22 +4485,27 @@ def CheckStyle(filename, clean_lines, linenum, file_extension, nesting_state, cleansed_line = clean_lines.elided[linenum] while initial_spaces < len(line) and line[initial_spaces] == ' ': initial_spaces += 1 - if line and line[-1].isspace(): - error(filename, linenum, 'whitespace/end_of_line', 4, - 'Line ends in whitespace. Consider deleting these extra spaces.') # There are certain situations we allow one space, notably for # section labels, and also lines containing multi-line raw strings. - elif ((initial_spaces == 1 or initial_spaces == 3) and - not Match(scope_or_label_pattern, cleansed_line) and - not (clean_lines.raw_lines[linenum] != line and - Match(r'^\s*""', line))): + # We also don't check for lines that look like continuation lines + # (of lines ending in double quotes, commas, equals, or angle brackets) + # because the rules for how to indent those are non-trivial. + if (not Search(r'[",=><] *$', prev) and + (initial_spaces == 1 or initial_spaces == 3) and + not Match(scope_or_label_pattern, cleansed_line) and + not (clean_lines.raw_lines[linenum] != line and + Match(r'^\s*""', line))): error(filename, linenum, 'whitespace/indent', 3, 'Weird number of spaces at line-start. ' 'Are you using a 2-space indent?') + if line and line[-1].isspace(): + error(filename, linenum, 'whitespace/end_of_line', 4, + 'Line ends in whitespace. Consider deleting these extra spaces.') + # Check if the line is a header guard. is_header_guard = False - if file_extension == 'h': + if file_extension in GetHeaderExtensions(): cppvar = GetHeaderGuardCPPVariable(filename) if (line.startswith('#ifndef %s' % cppvar) or line.startswith('#define %s' % cppvar) or @@ -4445,20 +4519,23 @@ def CheckStyle(filename, clean_lines, linenum, file_extension, nesting_state, # # The "$Id:...$" comment may also get very long without it being the # developers fault. + # + # Doxygen documentation copying can get pretty long when using an overloaded + # function declaration if (not line.startswith('#include') and not is_header_guard and not Match(r'^\s*//.*http(s?)://\S*$', line) and - not Match(r'^// \$Id:.*#[0-9]+ \$$', line)): + not Match(r'^\s*//\s*[^\s]*$', line) and + not Match(r'^// \$Id:.*#[0-9]+ \$$', line) and + not Match(r'^\s*/// [@\\](copydoc|copydetails|copybrief) .*$', line)): line_width = GetLineWidth(line) - extended_length = int((_line_length * 1.25)) - if line_width > extended_length: - error(filename, linenum, 'whitespace/line_length', 4, - 'Lines should very rarely be longer than %i characters' % - extended_length) - elif line_width > _line_length: + if line_width > _line_length: error(filename, linenum, 'whitespace/line_length', 2, 'Lines should be <= %i characters long' % _line_length) if (cleansed_line.count(';') > 1 and + # allow simple single line lambdas + not Match(r'^[^{};]*\[[^\[\]]*\][^{}]*\{[^{}\n\r]*\}', + line) and # for loops are allowed two ;'s (and may run over two lines). 
cleansed_line.find('for') == -1 and (GetPreviousNonBlankLine(clean_lines, linenum)[0].find('for') == -1 or @@ -4479,9 +4556,8 @@ def CheckStyle(filename, clean_lines, linenum, file_extension, nesting_state, CheckOperatorSpacing(filename, clean_lines, linenum, error) CheckParenthesisSpacing(filename, clean_lines, linenum, error) CheckCommaSpacing(filename, clean_lines, linenum, error) - CheckBracesSpacing(filename, clean_lines, linenum, error) + CheckBracesSpacing(filename, clean_lines, linenum, nesting_state, error) CheckSpacingForFunctionCall(filename, clean_lines, linenum, error) - CheckRValueReference(filename, clean_lines, linenum, nesting_state, error) CheckCheck(filename, clean_lines, linenum, error) CheckAltTokens(filename, clean_lines, linenum, error) classinfo = nesting_state.InnermostClass() @@ -4517,31 +4593,17 @@ def _DropCommonSuffixes(filename): Returns: The filename with the common suffix removed. """ - for suffix in ('test.cc', 'regtest.cc', 'unittest.cc', - 'inl.h', 'impl.h', 'internal.h'): + for suffix in itertools.chain( + ('%s.%s' % (test_suffix.lstrip('_'), ext) + for test_suffix, ext in itertools.product(_test_suffixes, GetNonHeaderExtensions())), + ('%s.%s' % (suffix, ext) + for suffix, ext in itertools.product(['inl', 'imp', 'internal'], GetHeaderExtensions()))): if (filename.endswith(suffix) and len(filename) > len(suffix) and filename[-len(suffix) - 1] in ('-', '_')): return filename[:-len(suffix) - 1] return os.path.splitext(filename)[0] -def _IsTestFilename(filename): - """Determines if the given filename has a suffix that identifies it as a test. - - Args: - filename: The input filename. - - Returns: - True if 'filename' looks like a test, False otherwise. - """ - if (filename.endswith('_test.cc') or - filename.endswith('_unittest.cc') or - filename.endswith('_regtest.cc')): - return True - else: - return False - - def _ClassifyInclude(fileinfo, include, is_system): """Figures out what kind of header 'include' is. @@ -4570,6 +4632,10 @@ def _ClassifyInclude(fileinfo, include, is_system): # those already checked for above. is_cpp_h = include in _CPP_HEADERS + # Headers with C++ extensions shouldn't be considered C system headers + if is_system and os.path.splitext(include)[1] in ['.hpp', '.hxx', '.h++']: + is_system = False + if is_system: if is_cpp_h: return _CPP_SYS_HEADER @@ -4582,9 +4648,11 @@ def _ClassifyInclude(fileinfo, include, is_system): target_dir, target_base = ( os.path.split(_DropCommonSuffixes(fileinfo.RepositoryName()))) include_dir, include_base = os.path.split(_DropCommonSuffixes(include)) + target_dir_pub = os.path.normpath(target_dir + '/../public') + target_dir_pub = target_dir_pub.replace('\\', '/') if target_base == include_base and ( include_dir == target_dir or - include_dir == os.path.normpath(target_dir + '/../public')): + include_dir == target_dir_pub): return _LIKELY_MY_HEADER # If the target and include share some initial basename @@ -4628,7 +4696,7 @@ def CheckIncludeLine(filename, clean_lines, linenum, include_state, error): # naming convention but not the include convention. match = Match(r'#include\s*"([^/]+\.h)"', line) if match and not _THIRD_PARTY_HEADERS_PATTERN.match(match.group(1)): - error(filename, linenum, 'build/include', 4, + error(filename, linenum, 'build/include_subdir', 4, 'Include the directory when naming .h files') # we shouldn't include a file more than once. 
actually, there are a @@ -4643,11 +4711,16 @@ def CheckIncludeLine(filename, clean_lines, linenum, include_state, error): error(filename, linenum, 'build/include', 4, '"%s" already included at %s:%s' % (include, filename, duplicate_line)) - elif (include.endswith('.cc') and + return + + for extension in GetNonHeaderExtensions(): + if (include.endswith('.' + extension) and os.path.dirname(fileinfo.RepositoryName()) != os.path.dirname(include)): - error(filename, linenum, 'build/include', 4, - 'Do not include .cc files from other packages') - elif not _THIRD_PARTY_HEADERS_PATTERN.match(include): + error(filename, linenum, 'build/include', 4, + 'Do not include .' + extension + ' files from other packages') + return + + if not _THIRD_PARTY_HEADERS_PATTERN.match(include): include_state.include_list[-1].append((include, linenum)) # We want to ensure that headers appear in the right order: @@ -4701,7 +4774,7 @@ def _GetTextInside(text, start_pattern): # Give opening punctuations to get the matching close-punctuations. matching_punctuation = {'(': ')', '{': '}', '[': ']'} - closing_punctuation = set(matching_punctuation.itervalues()) + closing_punctuation = set(itervalues(matching_punctuation)) # Find the position to start extracting text. match = re.search(start_pattern, text, re.M) @@ -4756,6 +4829,9 @@ def _GetTextInside(text, start_pattern): _RE_PATTERN_CONST_REF_PARAM = ( r'(?:.*\s*\bconst\s*&\s*' + _RE_PATTERN_IDENT + r'|const\s+' + _RE_PATTERN_TYPE + r'\s*&\s*' + _RE_PATTERN_IDENT + r')') +# Stream types. +_RE_PATTERN_REF_STREAM_PARAM = ( + r'(?:.*stream\s*&\s*' + _RE_PATTERN_IDENT + r')') def CheckLanguage(filename, clean_lines, linenum, file_extension, @@ -4792,15 +4868,13 @@ def CheckLanguage(filename, clean_lines, linenum, file_extension, if match: include_state.ResetSection(match.group(1)) - # Make Windows paths like Unix. - fullname = os.path.abspath(filename).replace('\\', '/') - + # Perform other checks now that we are sure that this is not an include line CheckCasts(filename, clean_lines, linenum, error) CheckGlobalStatic(filename, clean_lines, linenum, error) CheckPrintf(filename, clean_lines, linenum, error) - if file_extension == 'h': + if file_extension in GetHeaderExtensions(): # TODO(unknown): check that 1-arg constructors are explicit. # How to tell it's a constructor? # (handled in CheckForNonStandardConstructs for now) @@ -4861,9 +4935,14 @@ def CheckLanguage(filename, clean_lines, linenum, file_extension, % (match.group(1), match.group(2))) if Search(r'\busing namespace\b', line): - error(filename, linenum, 'build/namespaces', 5, - 'Do not use namespace using-directives. ' - 'Use using-declarations instead.') + if Search(r'\bliterals\b', line): + error(filename, linenum, 'build/namespaces_literals', 5, + 'Do not use namespace using-directives. ' + 'Use using-declarations instead.') + else: + error(filename, linenum, 'build/namespaces', 5, + 'Do not use namespace using-directives. ' + 'Use using-declarations instead.') # Detect variable-length arrays. match = Match(r'\s*(.+::)?(\w+) [a-z]\w*\[(.+)];', line) @@ -4907,12 +4986,12 @@ def CheckLanguage(filename, clean_lines, linenum, file_extension, # Check for use of unnamed namespaces in header files. Registration # macros are typically OK, so we allow use of "namespace {" on lines # that end with backslashes. 
- if (file_extension == 'h' + if (file_extension in GetHeaderExtensions() and Search(r'\bnamespace\s*{', line) and line[-1] != '\\'): error(filename, linenum, 'build/namespaces', 4, 'Do not use unnamed namespaces in header files. See ' - 'http://google-styleguide.googlecode.com/svn/trunk/cppguide.xml#Namespaces' + 'https://google-styleguide.googlecode.com/svn/trunk/cppguide.xml#Namespaces' ' for more information.') @@ -4933,9 +5012,13 @@ def CheckGlobalStatic(filename, clean_lines, linenum, error): # Check for people declaring static/global STL strings at the top level. # This is dangerous because the C++ language does not guarantee that - # globals with constructors are initialized before the first access. + # globals with constructors are initialized before the first access, and + # also because globals can be destroyed when some threads are still running. + # TODO(unknown): Generalize this to also find static unique_ptr instances. + # TODO(unknown): File bugs for clang-tidy to find these. match = Match( - r'((?:|static +)(?:|const +))string +([a-zA-Z0-9_:]+)\b(.*)', + r'((?:|static +)(?:|const +))(?::*std::)?string( +const)? +' + r'([a-zA-Z0-9_:]+)\b(.*)', line) # Remove false positives: @@ -4955,15 +5038,20 @@ def CheckGlobalStatic(filename, clean_lines, linenum, error): # matching identifiers. # string Class::operator*() if (match and - not Search(r'\bstring\b(\s+const)?\s*\*\s*(const\s+)?\w', line) and + not Search(r'\bstring\b(\s+const)?\s*[\*\&]\s*(const\s+)?\w', line) and not Search(r'\boperator\W', line) and - not Match(r'\s*(<.*>)?(::[a-zA-Z0-9_]+)*\s*\(([^"]|$)', match.group(3))): - error(filename, linenum, 'runtime/string', 4, - 'For a static/global string constant, use a C style string instead: ' - '"%schar %s[]".' % - (match.group(1), match.group(2))) + not Match(r'\s*(<.*>)?(::[a-zA-Z0-9_]+)*\s*\(([^"]|$)', match.group(4))): + if Search(r'\bconst\b', line): + error(filename, linenum, 'runtime/string', 4, + 'For a static/global string constant, use a C style string ' + 'instead: "%schar%s %s[]".' % + (match.group(1), match.group(2) or '', match.group(3))) + else: + error(filename, linenum, 'runtime/string', 4, + 'Static/global string variables are not permitted.') - if Search(r'\b([A-Za-z0-9_]*_)\(\1\)', line): + if (Search(r'\b([A-Za-z0-9_]*_)\(\1\)', line) or + Search(r'\b([A-Za-z0-9_]*_)\(CHECK_NOTNULL\(\1\)\)', line)): error(filename, linenum, 'runtime/init', 4, 'You seem to be initializing a member variable with itself.') @@ -5208,7 +5296,8 @@ def CheckForNonConstReference(filename, clean_lines, linenum, decls = ReplaceAll(r'{[^}]*}', ' ', line) # exclude function body for parameter in re.findall(_RE_PATTERN_REF_PARAM, decls): - if not Match(_RE_PATTERN_CONST_REF_PARAM, parameter): + if (not Match(_RE_PATTERN_CONST_REF_PARAM, parameter) and + not Match(_RE_PATTERN_REF_STREAM_PARAM, parameter)): error(filename, linenum, 'runtime/references', 2, 'Is this a non-const reference? ' 'If so, make const or use a pointer: ' + @@ -5231,7 +5320,7 @@ def CheckCasts(filename, clean_lines, linenum, error): # Parameterless conversion functions, such as bool(), are allowed as they are # probably a member operator declaration or default constructor. 
match = Search( - r'(\bnew\s+|\S<\s*(?:const\s+)?)?\b' + r'(\bnew\s+(?:const\s+)?|\S<\s*(?:const\s+)?)?\b' r'(int|float|double|bool|char|int32|uint32|int64|uint64)' r'(\([^)].*)', line) expecting_function = ExpectingFunctionArgs(clean_lines, linenum) @@ -5372,63 +5461,12 @@ def CheckCStyleCast(filename, clean_lines, linenum, cast_type, pattern, error): if context.endswith(' operator++') or context.endswith(' operator--'): return False - # A single unnamed argument for a function tends to look like old - # style cast. If we see those, don't issue warnings for deprecated - # casts, instead issue warnings for unnamed arguments where - # appropriate. - # - # These are things that we want warnings for, since the style guide - # explicitly require all parameters to be named: - # Function(int); - # Function(int) { - # ConstMember(int) const; - # ConstMember(int) const { - # ExceptionMember(int) throw (...); - # ExceptionMember(int) throw (...) { - # PureVirtual(int) = 0; - # [](int) -> bool { - # - # These are functions of some sort, where the compiler would be fine - # if they had named parameters, but people often omit those - # identifiers to reduce clutter: - # (FunctionPointer)(int); - # (FunctionPointer)(int) = value; - # Function((function_pointer_arg)(int)) - # Function((function_pointer_arg)(int), int param) - # ; - # <(FunctionPointerTemplateArgument)(int)>; + # A single unnamed argument for a function tends to look like old style cast. + # If we see those, don't issue warnings for deprecated casts. remainder = line[match.end(0):] if Match(r'^\s*(?:;|const\b|throw\b|final\b|override\b|[=>{),]|->)', remainder): - # Looks like an unnamed parameter. - - # Don't warn on any kind of template arguments. - if Match(r'^\s*>', remainder): - return False - - # Don't warn on assignments to function pointers, but keep warnings for - # unnamed parameters to pure virtual functions. Note that this pattern - # will also pass on assignments of "0" to function pointers, but the - # preferred values for those would be "nullptr" or "NULL". - matched_zero = Match(r'^\s=\s*(\S+)\s*;', remainder) - if matched_zero and matched_zero.group(1) != '0': - return False - - # Don't warn on function pointer declarations. For this we need - # to check what came before the "(type)" string. - if Match(r'.*\)\s*$', line[0:match.start(0)]): - return False - - # Don't warn if the parameter is named with block comments, e.g.: - # Function(int /*unused_param*/); - raw_line = clean_lines.raw_lines[linenum] - if '/*' in raw_line: - return False - - # Passed all filters, issue warning here. - error(filename, linenum, 'readability/function', 3, - 'All parameters should be named in a function') - return True + return False # At this point, all that should be left is actual casts. 
error(filename, linenum, 'readability/casting', 4, @@ -5482,12 +5520,15 @@ def ExpectingFunctionArgs(clean_lines, linenum): ('', ('numeric_limits',)), ('', ('list',)), ('', ('map', 'multimap',)), - ('', ('allocator',)), + ('', ('allocator', 'make_shared', 'make_unique', 'shared_ptr', + 'unique_ptr', 'weak_ptr')), ('', ('queue', 'priority_queue',)), ('', ('set', 'multiset',)), ('', ('stack',)), ('', ('char_traits', 'basic_string',)), ('', ('tuple',)), + ('', ('unordered_map', 'unordered_multimap')), + ('', ('unordered_set', 'unordered_multiset')), ('', ('pair',)), ('', ('vector',)), @@ -5498,18 +5539,26 @@ def ExpectingFunctionArgs(clean_lines, linenum): ('', ('slist',)), ) -_RE_PATTERN_STRING = re.compile(r'\bstring\b') +_HEADERS_MAYBE_TEMPLATES = ( + ('', ('copy', 'max', 'min', 'min_element', 'sort', + 'transform', + )), + ('', ('forward', 'make_pair', 'move', 'swap')), + ) -_re_pattern_algorithm_header = [] -for _template in ('copy', 'max', 'min', 'min_element', 'sort', 'swap', - 'transform'): - # Match max(..., ...), max(..., ...), but not foo->max, foo.max or - # type::max(). - _re_pattern_algorithm_header.append( - (re.compile(r'[^>.]\b' + _template + r'(<.*?>)?\([^\)]'), - _template, - '')) +_RE_PATTERN_STRING = re.compile(r'\bstring\b') +_re_pattern_headers_maybe_templates = [] +for _header, _templates in _HEADERS_MAYBE_TEMPLATES: + for _template in _templates: + # Match max(..., ...), max(..., ...), but not foo->max, foo.max or + # type::max(). + _re_pattern_headers_maybe_templates.append( + (re.compile(r'[^>.]\b' + _template + r'(<.*?>)?\([^\)]'), + _template, + _header)) + +# Other scripts may reach in and modify this pattern. _re_pattern_templates = [] for _header, _templates in _HEADERS_CONTAINING_TEMPLATES: for _template in _templates: @@ -5540,7 +5589,7 @@ def FilesBelongToSameModule(filename_cc, filename_h): some false positives. This should be sufficiently rare in practice. Args: - filename_cc: is the path for the .cc file + filename_cc: is the path for the source (e.g. .cc) file filename_h: is the path for the header path Returns: @@ -5548,20 +5597,23 @@ def FilesBelongToSameModule(filename_cc, filename_h): bool: True if filename_cc and filename_h belong to the same module. string: the additional prefix needed to open the header file. 
""" + fileinfo_cc = FileInfo(filename_cc) + if not fileinfo_cc.Extension().lstrip('.') in GetNonHeaderExtensions(): + return (False, '') - if not filename_cc.endswith('.cc'): + fileinfo_h = FileInfo(filename_h) + if not fileinfo_h.Extension().lstrip('.') in GetHeaderExtensions(): return (False, '') - filename_cc = filename_cc[:-len('.cc')] - if filename_cc.endswith('_unittest'): - filename_cc = filename_cc[:-len('_unittest')] - elif filename_cc.endswith('_test'): - filename_cc = filename_cc[:-len('_test')] + + filename_cc = filename_cc[:-(len(fileinfo_cc.Extension()))] + matched_test_suffix = Search(_TEST_FILE_SUFFIX, fileinfo_cc.BaseName()) + if matched_test_suffix: + filename_cc = filename_cc[:-len(matched_test_suffix.group(1))] + filename_cc = filename_cc.replace('/public/', '/') filename_cc = filename_cc.replace('/internal/', '/') - if not filename_h.endswith('.h'): - return (False, '') - filename_h = filename_h[:-len('.h')] + filename_h = filename_h[:-(len(fileinfo_h.Extension()))] if filename_h.endswith('-inl'): filename_h = filename_h[:-len('-inl')] filename_h = filename_h.replace('/public/', '/') @@ -5622,7 +5674,7 @@ def CheckForIncludeWhatYouUse(filename, clean_lines, include_state, error, required = {} # A map of header name to linenumber and the template entity. # Example of required: { '': (1219, 'less<>') } - for linenum in xrange(clean_lines.NumLines()): + for linenum in range(clean_lines.NumLines()): line = clean_lines.elided[linenum] if not line or line[0] == '#': continue @@ -5636,7 +5688,7 @@ def CheckForIncludeWhatYouUse(filename, clean_lines, include_state, error, if prefix.endswith('std::') or not prefix.endswith('::'): required[''] = (linenum, 'string') - for pattern, template, header in _re_pattern_algorithm_header: + for pattern, template, header in _re_pattern_headers_maybe_templates: if pattern.search(line): required[header] = (linenum, template) @@ -5645,8 +5697,13 @@ def CheckForIncludeWhatYouUse(filename, clean_lines, include_state, error, continue for pattern, template, header in _re_pattern_templates: - if pattern.search(line): - required[header] = (linenum, template) + matched = pattern.search(line) + if matched: + # Don't warn about IWYU in non-STL namespaces: + # (We check only the first match per line; good enough.) + prefix = line[:matched.start()] + if prefix.endswith('std::') or not prefix.endswith('::'): + required[header] = (linenum, template) # The policy is that if you #include something in foo.h you don't need to # include it again in foo.cc. Here, we will look at possible includes. @@ -5671,7 +5728,7 @@ def CheckForIncludeWhatYouUse(filename, clean_lines, include_state, error, # include_dict is modified during iteration, so we iterate over a copy of # the keys. - header_keys = include_dict.keys() + header_keys = list(include_dict.keys()) for header in header_keys: (same_module, common_path) = FilesBelongToSameModule(abs_filename, header) fullpath = common_path + header @@ -5683,11 +5740,13 @@ def CheckForIncludeWhatYouUse(filename, clean_lines, include_state, error, # didn't include it in the .h file. # TODO(unknown): Do a better job of finding .h files so we are confident that # not having the .h file means there isn't one. - if filename.endswith('.cc') and not header_found: - return + if not header_found: + for extension in GetNonHeaderExtensions(): + if filename.endswith('.' + extension): + return # All the lines have been processed, report the errors found. 
- for required_header_unstripped in required: + for required_header_unstripped in sorted(required, key=required.__getitem__): template = required[required_header_unstripped][1] if required_header_unstripped.strip('<>"') not in include_dict: error(filename, required[required_header_unstripped][0], @@ -5719,31 +5778,6 @@ def CheckMakePairUsesDeduction(filename, clean_lines, linenum, error): ' OR use pair directly OR if appropriate, construct a pair directly') -def CheckDefaultLambdaCaptures(filename, clean_lines, linenum, error): - """Check that default lambda captures are not used. - - Args: - filename: The name of the current file. - clean_lines: A CleansedLines instance containing the file. - linenum: The number of the line to check. - error: The function to call with any errors found. - """ - line = clean_lines.elided[linenum] - - # A lambda introducer specifies a default capture if it starts with "[=" - # or if it starts with "[&" _not_ followed by an identifier. - match = Match(r'^(.*)\[\s*(?:=|&[^\w])', line) - if match: - # Found a potential error, check what comes after the lambda-introducer. - # If it's not open parenthesis (for lambda-declarator) or open brace - # (for compound-statement), it's not a lambda. - line, _, pos = CloseExpression(clean_lines, linenum, len(match.group(1))) - if pos >= 0 and Match(r'^\s*[{(]', line[pos:]): - error(filename, linenum, 'build/c++11', - 4, # 4 = high confidence - 'Default lambda captures are an unapproved C++ feature.') - - def CheckRedundantVirtual(filename, clean_lines, linenum, error): """Check if line contains a redundant "virtual" function-specifier. @@ -5851,11 +5885,9 @@ def IsBlockInNameSpace(nesting_state, is_forward_declaration): Whether or not the new block is directly in a namespace. """ if is_forward_declaration: - if len(nesting_state.stack) >= 1 and ( - isinstance(nesting_state.stack[-1], _NamespaceInfo)): - return True - else: - return False + return len(nesting_state.stack) >= 1 and ( + isinstance(nesting_state.stack[-1], _NamespaceInfo)) + return (len(nesting_state.stack) > 1 and nesting_state.stack[-1].check_namespace_indentation and @@ -5905,7 +5937,7 @@ def CheckItemIndentationInNamespace(filename, raw_lines_no_comments, linenum, def ProcessLine(filename, file_extension, clean_lines, line, include_state, function_state, nesting_state, error, - extra_check_functions=[]): + extra_check_functions=None): """Processes a single line in the file. Args: @@ -5942,11 +5974,11 @@ def ProcessLine(filename, file_extension, clean_lines, line, CheckPosixThreading(filename, clean_lines, line, error) CheckInvalidIncrement(filename, clean_lines, line, error) CheckMakePairUsesDeduction(filename, clean_lines, line, error) - CheckDefaultLambdaCaptures(filename, clean_lines, line, error) CheckRedundantVirtual(filename, clean_lines, line, error) CheckRedundantOverrideOrFinal(filename, clean_lines, line, error) - for check_fn in extra_check_functions: - check_fn(filename, clean_lines, line, error) + if extra_check_functions: + for check_fn in extra_check_functions: + check_fn(filename, clean_lines, line, error) def FlagCxx11Features(filename, clean_lines, linenum, error): """Flag those c++11 features that we only allow in certain places. @@ -5959,8 +5991,14 @@ def FlagCxx11Features(filename, clean_lines, linenum, error): """ line = clean_lines.elided[linenum] - # Flag unapproved C++11 headers. include = Match(r'\s*#\s*include\s+[<"]([^<"]+)[">]', line) + + # Flag unapproved C++ TR1 headers. 
+ if include and include.group(1).startswith('tr1/'): + error(filename, linenum, 'build/c++tr1', 5, + ('C++ TR1 headers such as <%s> are unapproved.') % include.group(1)) + + # Flag unapproved C++11 headers. if include and include.group(1) in ('cfenv', 'condition_variable', 'fenv.h', @@ -5994,8 +6032,27 @@ def FlagCxx11Features(filename, clean_lines, linenum, error): 'they may let you use it.') % top_name) +def FlagCxx14Features(filename, clean_lines, linenum, error): + """Flag those C++14 features that we restrict. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + line = clean_lines.elided[linenum] + + include = Match(r'\s*#\s*include\s+[<"]([^<"]+)[">]', line) + + # Flag unapproved C++14 headers. + if include and include.group(1) in ('scoped_allocator', 'shared_mutex'): + error(filename, linenum, 'build/c++14', 5, + ('<%s> is an unapproved C++14 header.') % include.group(1)) + + def ProcessFileData(filename, file_extension, lines, error, - extra_check_functions=[]): + extra_check_functions=None): """Performs lint checks and reports any errors to the given error function. Args: @@ -6019,14 +6076,14 @@ def ProcessFileData(filename, file_extension, lines, error, ResetNolintSuppressions() CheckForCopyright(filename, lines, error) - + ProcessGlobalSuppresions(lines) RemoveMultiLineComments(filename, lines, error) clean_lines = CleansedLines(lines) - if file_extension == 'h': + if file_extension in GetHeaderExtensions(): CheckForHeaderGuard(filename, clean_lines, error) - for line in xrange(clean_lines.NumLines()): + for line in range(clean_lines.NumLines()): ProcessLine(filename, file_extension, clean_lines, line, include_state, function_state, nesting_state, error, extra_check_functions) @@ -6034,9 +6091,9 @@ def ProcessFileData(filename, file_extension, lines, error, nesting_state.CheckCompletedBlocks(filename, error) CheckForIncludeWhatYouUse(filename, clean_lines, include_state, error) - + # Check that the .cc file has included its header if it exists. - if file_extension == 'cc': + if _IsSourceExtension(file_extension): CheckHeaderFileIncluded(filename, include_state, error) # We check here rather than inside ProcessLine so that we see raw @@ -6092,36 +6149,56 @@ def ProcessConfigOverrides(filename): if base_name: pattern = re.compile(val) if pattern.match(base_name): - sys.stderr.write('Ignoring "%s": file excluded by "%s". ' - 'File path component "%s" matches ' - 'pattern "%s"\n' % - (filename, cfg_file, base_name, val)) + _cpplint_state.PrintInfo('Ignoring "%s": file excluded by ' + '"%s". 
File path component "%s" matches pattern "%s"\n' % + (filename, cfg_file, base_name, val)) return False elif name == 'linelength': global _line_length try: _line_length = int(val) except ValueError: - sys.stderr.write('Line length must be numeric.') + _cpplint_state.PrintError('Line length must be numeric.') + elif name == 'extensions': + global _valid_extensions + try: + extensions = [ext.strip() for ext in val.split(',')] + _valid_extensions = set(extensions) + except ValueError: + sys.stderr.write('Extensions should be a comma-separated list of values;' + 'for example: extensions=hpp,cpp\n' + 'This could not be parsed: "%s"' % (val,)) + elif name == 'headers': + global _header_extensions + try: + extensions = [ext.strip() for ext in val.split(',')] + _header_extensions = set(extensions) + except ValueError: + sys.stderr.write('Extensions should be a comma-separated list of values;' + 'for example: extensions=hpp,cpp\n' + 'This could not be parsed: "%s"' % (val,)) + elif name == 'root': + global _root + _root = val else: - sys.stderr.write( + _cpplint_state.PrintError( 'Invalid configuration option (%s) in file %s\n' % (name, cfg_file)) except IOError: - sys.stderr.write( + _cpplint_state.PrintError( "Skipping config file '%s': Can't open for reading\n" % cfg_file) keep_looking = False # Apply all the accumulated filters in reverse order (top-level directory # config options having the least priority). - for filter in reversed(cfg_filters): - _AddFilters(filter) + for cfg_filter in reversed(cfg_filters): + _AddFilters(cfg_filter) return True -def ProcessFile(filename, vlevel, extra_check_functions=[]): +def ProcessFile(filename, vlevel, extra_check_functions=None): """Does google-lint on a single file. Args: @@ -6170,7 +6247,7 @@ def ProcessFile(filename, vlevel, extra_check_functions=[]): lf_lines.append(linenum + 1) except IOError: - sys.stderr.write( + _cpplint_state.PrintError( "Skipping input '%s': Can't open for reading\n" % filename) _RestoreFilters() return @@ -6180,9 +6257,9 @@ def ProcessFile(filename, vlevel, extra_check_functions=[]): # When reading from stdin, the extension is unknown, so no cpplint tests # should rely on the extension. - if filename != '-' and file_extension not in _valid_extensions: - sys.stderr.write('Ignoring %s; not a valid file name ' - '(%s)\n' % (filename, ', '.join(_valid_extensions))) + if filename != '-' and file_extension not in GetAllExtensions(): + _cpplint_state.PrintError('Ignoring %s; not a valid file name ' + '(%s)\n' % (filename, ', '.join(GetAllExtensions()))) else: ProcessFileData(filename, file_extension, lines, Error, extra_check_functions) @@ -6205,7 +6282,7 @@ def ProcessFile(filename, vlevel, extra_check_functions=[]): Error(filename, linenum, 'whitespace/newline', 1, 'Unexpected \\r (^M) found; better to use only \\n') - sys.stderr.write('Done processing %s\n' % filename) + _cpplint_state.PrintInfo('Done processing %s\n' % filename) _RestoreFilters() @@ -6216,10 +6293,11 @@ def PrintUsage(message): message: The optional error message. 
""" sys.stderr.write(_USAGE) + if message: sys.exit('\nFATAL ERROR: ' + message) else: - sys.exit(1) + sys.exit(0) def PrintCategories(): @@ -6247,8 +6325,13 @@ def ParseArguments(args): 'counting=', 'filter=', 'root=', + 'repository=', 'linelength=', - 'extensions=']) + 'extensions=', + 'exclude=', + 'headers=', + 'quiet', + 'recursive']) except getopt.GetoptError: PrintUsage('Invalid arguments.') @@ -6256,13 +6339,15 @@ def ParseArguments(args): output_format = _OutputFormat() filters = '' counting_style = '' + recursive = False for (opt, val) in opts: if opt == '--help': PrintUsage(None) elif opt == '--output': - if val not in ('emacs', 'vs7', 'eclipse'): - PrintUsage('The only allowed output formats are emacs, vs7 and eclipse.') + if val not in ('emacs', 'vs7', 'eclipse', 'junit'): + PrintUsage('The only allowed output formats are emacs, vs7, eclipse ' + 'and junit.') output_format = val elif opt == '--verbose': verbosity = int(val) @@ -6277,22 +6362,47 @@ def ParseArguments(args): elif opt == '--root': global _root _root = val + elif opt == '--repository': + global _repository + _repository = val elif opt == '--linelength': global _line_length try: - _line_length = int(val) + _line_length = int(val) except ValueError: - PrintUsage('Line length must be digits.') + PrintUsage('Line length must be digits.') + elif opt == '--exclude': + global _excludes + if not _excludes: + _excludes = set() + _excludes.update(glob.glob(val)) elif opt == '--extensions': global _valid_extensions try: - _valid_extensions = set(val.split(',')) + _valid_extensions = set(val.split(',')) except ValueError: PrintUsage('Extensions must be comma seperated list.') + elif opt == '--headers': + global _header_extensions + try: + _header_extensions = set(val.split(',')) + except ValueError: + PrintUsage('Extensions must be comma seperated list.') + elif opt == '--recursive': + recursive = True + elif opt == '--quiet': + global _quiet + _quiet = True if not filenames: PrintUsage('No files were specified.') + if recursive: + filenames = _ExpandDirectories(filenames) + + if _excludes: + filenames = _FilterExcludedFiles(filenames) + _SetOutputFormat(output_format) _SetVerboseLevel(verbosity) _SetFilters(filters) @@ -6300,21 +6410,63 @@ def ParseArguments(args): return filenames +def _ExpandDirectories(filenames): + """Searches a list of filenames and replaces directories in the list with + all files descending from those directories. Files with extensions not in + the valid extensions list are excluded. + + Args: + filenames: A list of files or directories + + Returns: + A list of all files that are members of filenames or descended from a + directory in filenames + """ + expanded = set() + for filename in filenames: + if not os.path.isdir(filename): + expanded.add(filename) + continue + + for root, _, files in os.walk(filename): + for loopfile in files: + fullname = os.path.join(root, loopfile) + if fullname.startswith('.' + os.path.sep): + fullname = fullname[len('.' + os.path.sep):] + expanded.add(fullname) + + filtered = [] + for filename in expanded: + if os.path.splitext(filename)[1][1:] in GetAllExtensions(): + filtered.append(filename) + + return filtered + +def _FilterExcludedFiles(filenames): + """Filters out files listed in the --exclude command line switch. 
File paths + in the switch are evaluated relative to the current working directory + """ + exclude_paths = [os.path.abspath(f) for f in _excludes] + return [f for f in filenames if os.path.abspath(f) not in exclude_paths] def main(): filenames = ParseArguments(sys.argv[1:]) + backup_err = sys.stderr + try: + # Change stderr to write with replacement characters so we don't die + # if we try to print something containing non-ASCII characters. + sys.stderr = codecs.StreamReader(sys.stderr, 'replace') - # Change stderr to write with replacement characters so we don't die - # if we try to print something containing non-ASCII characters. - sys.stderr = codecs.StreamReaderWriter(sys.stderr, - codecs.getreader('utf8'), - codecs.getwriter('utf8'), - 'replace') + _cpplint_state.ResetErrorCounts() + for filename in filenames: + ProcessFile(filename, _cpplint_state.verbose_level) + _cpplint_state.PrintErrorCounts() - _cpplint_state.ResetErrorCounts() - for filename in filenames: - ProcessFile(filename, _cpplint_state.verbose_level) - _cpplint_state.PrintErrorCounts() + if _cpplint_state.output_format == 'junit': + sys.stderr.write(_cpplint_state.FormatJUnitXML()) + + finally: + sys.stderr = backup_err sys.exit(_cpplint_state.error_count > 0) diff --git a/cpp/cmake_modules/SetupCxxFlags.cmake b/cpp/cmake_modules/SetupCxxFlags.cmake index 97aed6b274976..d901bde47c631 100644 --- a/cpp/cmake_modules/SetupCxxFlags.cmake +++ b/cpp/cmake_modules/SetupCxxFlags.cmake @@ -100,7 +100,7 @@ if ("${UPPERCASE_BUILD_WARNING_LEVEL}" STREQUAL "CHECKIN") -Wno-cast-align -Wno-vla-extension -Wno-shift-sign-overflow \ -Wno-used-but-marked-unused -Wno-missing-variable-declarations \ -Wno-gnu-zero-variadic-macro-arguments -Wconversion -Wno-sign-conversion \ --Wno-disabled-macro-expansion -Wno-shorten-64-to-32") +-Wno-disabled-macro-expansion") # Version numbers where warnings are introduced if ("${COMPILER_VERSION}" VERSION_GREATER "3.3") diff --git a/cpp/cmake_modules/ThirdpartyToolchain.cmake b/cpp/cmake_modules/ThirdpartyToolchain.cmake index 4f64434170655..69812b97cc770 100644 --- a/cpp/cmake_modules/ThirdpartyToolchain.cmake +++ b/cpp/cmake_modules/ThirdpartyToolchain.cmake @@ -926,8 +926,7 @@ if (ARROW_ORC) -DZLIB_HOME=${ZLIB_HOME}) ExternalProject_Add(orc_ep - GIT_REPOSITORY "https://github.com/apache/orc" - GIT_TAG ${ORC_VERSION} + URL "https://github.com/apache/orc/archive/${ORC_VERSION}.tar.gz" BUILD_BYPRODUCTS ${ORC_STATIC_LIB} CMAKE_ARGS ${ORC_CMAKE_ARGS}) diff --git a/cpp/src/arrow/CMakeLists.txt b/cpp/src/arrow/CMakeLists.txt index ad86256e0be34..74674bebb43be 100644 --- a/cpp/src/arrow/CMakeLists.txt +++ b/cpp/src/arrow/CMakeLists.txt @@ -153,6 +153,7 @@ install(FILES pretty_print.h record_batch.h status.h + stl.h table.h table_builder.h tensor.h @@ -183,6 +184,7 @@ ADD_ARROW_TEST(memory_pool-test) ADD_ARROW_TEST(pretty_print-test) ADD_ARROW_TEST(public-api-test) ADD_ARROW_TEST(status-test) +ADD_ARROW_TEST(stl-test) ADD_ARROW_TEST(type-test) ADD_ARROW_TEST(table-test) ADD_ARROW_TEST(table_builder-test) diff --git a/cpp/src/arrow/adapters/orc/adapter.cc b/cpp/src/arrow/adapters/orc/adapter.cc index 473c90f925124..f253808e34ffc 100644 --- a/cpp/src/arrow/adapters/orc/adapter.cc +++ b/cpp/src/arrow/adapters/orc/adapter.cc @@ -23,6 +23,7 @@ #include #include #include +#include #include #include "arrow/buffer.h" @@ -105,6 +106,8 @@ Status GetArrowType(const liborc::Type* type, std::shared_ptr* out) { return Status::OK(); } liborc::TypeKind kind = type->getKind(); + const int subtype_count = 
static_cast(type->getSubtypeCount()); + switch (kind) { case liborc::BOOLEAN: *out = boolean(); @@ -135,7 +138,7 @@ Status GetArrowType(const liborc::Type* type, std::shared_ptr* out) { *out = binary(); break; case liborc::CHAR: - *out = fixed_size_binary(type->getMaximumLength()); + *out = fixed_size_binary(static_cast(type->getMaximumLength())); break; case liborc::TIMESTAMP: *out = timestamp(TimeUnit::NANO); @@ -144,16 +147,18 @@ Status GetArrowType(const liborc::Type* type, std::shared_ptr* out) { *out = date32(); break; case liborc::DECIMAL: { - if (type->getPrecision() == 0) { + const int precision = static_cast(type->getPrecision()); + const int scale = static_cast(type->getScale()); + if (precision == 0) { // In HIVE 0.11/0.12 precision is set as 0, but means max precision *out = decimal(38, 6); } else { - *out = decimal(type->getPrecision(), type->getScale()); + *out = decimal(precision, scale); } break; } case liborc::LIST: { - if (type->getSubtypeCount() != 1) { + if (subtype_count != 1) { return Status::Invalid("Invalid Orc List type"); } std::shared_ptr elemtype; @@ -162,7 +167,7 @@ Status GetArrowType(const liborc::Type* type, std::shared_ptr* out) { break; } case liborc::MAP: { - if (type->getSubtypeCount() != 2) { + if (subtype_count != 2) { return Status::Invalid("Invalid Orc Map type"); } std::shared_ptr keytype; @@ -173,9 +178,8 @@ Status GetArrowType(const liborc::Type* type, std::shared_ptr* out) { break; } case liborc::STRUCT: { - int size = type->getSubtypeCount(); std::vector> fields; - for (int child = 0; child < size; ++child) { + for (int child = 0; child < subtype_count; ++child) { std::shared_ptr elemtype; RETURN_NOT_OK(GetArrowType(type->getSubtype(child), &elemtype)); std::string name = type->getFieldName(child); @@ -185,10 +189,9 @@ Status GetArrowType(const liborc::Type* type, std::shared_ptr* out) { break; } case liborc::UNION: { - int size = type->getSubtypeCount(); std::vector> fields; std::vector type_codes; - for (int child = 0; child < size; ++child) { + for (int child = 0; child < subtype_count; ++child) { std::shared_ptr elemtype; RETURN_NOT_OK(GetArrowType(type->getSubtype(child), &elemtype)); fields.push_back(field("_union_" + std::to_string(child), elemtype)); @@ -259,7 +262,7 @@ class ORCFileReader::Impl { "Only ORC files with a top-level struct " "can be handled"); } - int size = type.getSubtypeCount(); + int size = static_cast(type.getSubtypeCount()); std::vector> fields; for (int child = 0; child < size; ++child) { std::shared_ptr elemtype; @@ -449,7 +452,7 @@ class ORCFileReader::Impl { const liborc::Type* elemtype = type->getSubtype(0); const bool has_nulls = batch->hasNulls; - for (int i = offset; i < length + offset; i++) { + for (int64_t i = offset; i < length + offset; i++) { if (!has_nulls || batch->notNull[i]) { int64_t start = batch->offsets[i]; int64_t end = batch->offsets[i + 1]; @@ -474,7 +477,7 @@ class ORCFileReader::Impl { const liborc::Type* valtype = type->getSubtype(1); const bool has_nulls = batch->hasNulls; - for (int i = offset; i < length + offset; i++) { + for (int64_t i = offset; i < length + offset; i++) { RETURN_NOT_OK(list_builder->Append()); int64_t start = batch->offsets[i]; int64_t list_length = batch->offsets[i + 1] - start; @@ -516,7 +519,7 @@ class ORCFileReader::Impl { if (length == 0) { return Status::OK(); } - int start = builder->length(); + int64_t start = builder->length(); const uint8_t* valid_bytes = nullptr; if (batch->hasNulls) { @@ -540,7 +543,7 @@ class ORCFileReader::Impl { if (length == 0) { 
return Status::OK(); } - int start = builder->length(); + int64_t start = builder->length(); const uint8_t* valid_bytes = nullptr; if (batch->hasNulls) { @@ -551,7 +554,7 @@ class ORCFileReader::Impl { const int64_t* source = batch->data.data() + offset; uint8_t* target = reinterpret_cast(builder->data()->mutable_data()); - for (int i = 0; i < length; i++) { + for (int64_t i = 0; i < length; i++) { if (source[i]) { BitUtil::SetBit(target, start + i); } else { @@ -569,7 +572,7 @@ class ORCFileReader::Impl { if (length == 0) { return Status::OK(); } - int start = builder->length(); + int64_t start = builder->length(); const uint8_t* valid_bytes = nullptr; if (batch->hasNulls) { @@ -581,7 +584,7 @@ class ORCFileReader::Impl { const int64_t* nanos = batch->nanoseconds.data() + offset; int64_t* target = reinterpret_cast(builder->data()->mutable_data()); - for (int i = 0; i < length; i++) { + for (int64_t i = 0; i < length; i++) { // TODO: boundscheck this, as ORC supports higher resolution timestamps // than arrow for nanosecond resolution target[start + i] = seconds[i] * kOneSecondNanos + nanos[i]; @@ -596,9 +599,10 @@ class ORCFileReader::Impl { auto batch = static_cast(cbatch); const bool has_nulls = batch->hasNulls; - for (int i = offset; i < length + offset; i++) { + for (int64_t i = offset; i < length + offset; i++) { if (!has_nulls || batch->notNull[i]) { - RETURN_NOT_OK(builder->Append(batch->data[i], batch->length[i])); + RETURN_NOT_OK( + builder->Append(batch->data[i], static_cast(batch->length[i]))); } else { RETURN_NOT_OK(builder->AppendNull()); } @@ -612,7 +616,7 @@ class ORCFileReader::Impl { auto batch = static_cast(cbatch); const bool has_nulls = batch->hasNulls; - for (int i = offset; i < length + offset; i++) { + for (int64_t i = offset; i < length + offset; i++) { if (!has_nulls || batch->notNull[i]) { RETURN_NOT_OK(builder->Append(batch->data[i])); } else { @@ -629,7 +633,7 @@ class ORCFileReader::Impl { const bool has_nulls = cbatch->hasNulls; if (type->getPrecision() == 0 || type->getPrecision() > 18) { auto batch = static_cast(cbatch); - for (int i = offset; i < length + offset; i++) { + for (int64_t i = offset; i < length + offset; i++) { if (!has_nulls || batch->notNull[i]) { RETURN_NOT_OK(builder->Append( Decimal128(batch->values[i].getHighBits(), batch->values[i].getLowBits()))); @@ -639,7 +643,7 @@ class ORCFileReader::Impl { } } else { auto batch = static_cast(cbatch); - for (int i = offset; i < length + offset; i++) { + for (int64_t i = offset; i < length + offset; i++) { if (!has_nulls || batch->notNull[i]) { RETURN_NOT_OK(builder->Append(Decimal128(batch->values[i]))); } else { diff --git a/cpp/src/arrow/array-test.cc b/cpp/src/arrow/array-test.cc index 7ff3261ecba5e..c53da8591e94e 100644 --- a/cpp/src/arrow/array-test.cc +++ b/cpp/src/arrow/array-test.cc @@ -1155,6 +1155,45 @@ TEST_F(TestBinaryBuilder, TestScalarAppend) { } } +TEST_F(TestBinaryBuilder, TestCapacityReserve) { + vector strings = {"aaaaa", "bbbbbbbbbb", "ccccccccccccccc", "dddddddddd"}; + int N = static_cast(strings.size()); + int reps = 15; + int64_t length = 0; + int64_t capacity = 1000; + int64_t expected_capacity = BitUtil::RoundUpToMultipleOf64(capacity); + + ASSERT_OK(builder_->ReserveData(capacity)); + + ASSERT_EQ(length, builder_->value_data_length()); + ASSERT_EQ(expected_capacity, builder_->value_data_capacity()); + + for (int j = 0; j < reps; ++j) { + for (int i = 0; i < N; ++i) { + ASSERT_OK(builder_->Append(strings[i])); + length += static_cast(strings[i].size()); + + ASSERT_EQ(length, 
builder_->value_data_length());
+      ASSERT_EQ(expected_capacity, builder_->value_data_capacity());
+    }
+  }
+
+  int extra_capacity = 500;
+  expected_capacity = BitUtil::RoundUpToMultipleOf64(length + extra_capacity);
+
+  ASSERT_OK(builder_->ReserveData(extra_capacity));
+
+  ASSERT_EQ(length, builder_->value_data_length());
+  ASSERT_EQ(expected_capacity, builder_->value_data_capacity());
+
+  Done();
+
+  ASSERT_EQ(reps * N, result_->length());
+  ASSERT_EQ(0, result_->null_count());
+  ASSERT_EQ(reps * 40, result_->value_data()->size());
+  ASSERT_EQ(expected_capacity, result_->value_data()->capacity());
+}
+
 TEST_F(TestBinaryBuilder, TestZeroLength) {
   // All buffers are null
   Done();
diff --git a/cpp/src/arrow/array.cc b/cpp/src/arrow/array.cc
index 144fbcd05c205..3d72761ed18e5 100644
--- a/cpp/src/arrow/array.cc
+++ b/cpp/src/arrow/array.cc
@@ -21,6 +21,7 @@
 #include
 #include
 #include
+#include
 
 #include "arrow/buffer.h"
 #include "arrow/compare.h"
diff --git a/cpp/src/arrow/array.h b/cpp/src/arrow/array.h
index 0ae1ddd8ea221..f0a786131b2b5 100644
--- a/cpp/src/arrow/array.h
+++ b/cpp/src/arrow/array.h
@@ -22,6 +22,7 @@
 #include
 #include
 #include
+#include
 #include
 
 #include "arrow/buffer.h"
diff --git a/cpp/src/arrow/buffer-test.cc b/cpp/src/arrow/buffer-test.cc
index 5fd2706f0466b..a24384a383395 100644
--- a/cpp/src/arrow/buffer-test.cc
+++ b/cpp/src/arrow/buffer-test.cc
@@ -52,6 +52,23 @@ TEST(TestBuffer, FromStdString) {
   ASSERT_EQ(static_cast<int64_t>(val.size()), buf.size());
 }
 
+TEST(TestBuffer, FromStdStringWithMemory) {
+  std::string expected = "hello, world";
+  std::shared_ptr<Buffer> buf;
+
+  {
+    std::string temp = "hello, world";
+    ASSERT_OK(Buffer::FromString(temp, &buf));
+    ASSERT_EQ(0, memcmp(buf->data(), temp.c_str(), temp.size()));
+    ASSERT_EQ(static_cast<int64_t>(temp.size()), buf->size());
+  }
+
+  // Now temp goes out of scope and we check if created buffer
+  // is still valid to make sure it actually owns its space
+  ASSERT_EQ(0, memcmp(buf->data(), expected.c_str(), expected.size()));
+  ASSERT_EQ(static_cast<int64_t>(expected.size()), buf->size());
+}
+
 TEST(TestBuffer, Resize) {
   PoolBuffer buf;
 
@@ -194,4 +211,29 @@ TEST(TestBuffer, SliceMutableBuffer) {
   ASSERT_TRUE(slice->Equals(expected));
 }
 
+TEST(TestBufferBuilder, ResizeReserve) {
+  const std::string data = "some data";
+  auto data_ptr = data.c_str();
+
+  BufferBuilder builder;
+
+  ASSERT_OK(builder.Append(data_ptr, 9));
+  ASSERT_EQ(9, builder.length());
+
+  ASSERT_OK(builder.Resize(128));
+  ASSERT_EQ(128, builder.capacity());
+
+  // Do not shrink to fit
+  ASSERT_OK(builder.Resize(64, false));
+  ASSERT_EQ(128, builder.capacity());
+
+  // Shrink to fit
+  ASSERT_OK(builder.Resize(64));
+  ASSERT_EQ(64, builder.capacity());
+
+  // Reserve elements
+  ASSERT_OK(builder.Reserve(60));
+  ASSERT_EQ(128, builder.capacity());
+}
+
 }  // namespace arrow
diff --git a/cpp/src/arrow/buffer.cc b/cpp/src/arrow/buffer.cc
index 1b8e4375445bb..29e2c242a3f4a 100644
--- a/cpp/src/arrow/buffer.cc
+++ b/cpp/src/arrow/buffer.cc
@@ -58,6 +58,18 @@ bool Buffer::Equals(const Buffer& other) const {
                             !memcmp(data_, other.data_, static_cast<size_t>(size_))));
 }
 
+Status Buffer::FromString(const std::string& data, MemoryPool* pool,
+                          std::shared_ptr<Buffer>* out) {
+  auto size = static_cast<int64_t>(data.size());
+  RETURN_NOT_OK(AllocateBuffer(pool, size, out));
+  std::copy(data.c_str(), data.c_str() + size, (*out)->mutable_data());
+  return Status::OK();
+}
+
+Status Buffer::FromString(const std::string& data, std::shared_ptr<Buffer>* out) {
+  return FromString(data, default_memory_pool(), out);
+}
+
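(Editorial aside, not part of the patch: a minimal sketch of how the `Buffer::FromString` overloads added above might be used from application code. It assumes the single-argument overload shown above, which falls back to `default_memory_pool()`.)

```cpp
#include <iostream>
#include <memory>
#include <string>

#include "arrow/buffer.h"
#include "arrow/status.h"

int main() {
  std::shared_ptr<arrow::Buffer> buf;
  {
    // FromString copies the bytes into pool-allocated memory owned by the
    // buffer, so the source string may go out of scope afterwards.
    std::string payload = "hello, world";
    arrow::Status st = arrow::Buffer::FromString(payload, &buf);
    if (!st.ok()) {
      std::cerr << st.ToString() << std::endl;
      return 1;
    }
  }
  // Still valid here: unlike the base Buffer class, this buffer owns its memory.
  std::cout << buf->size() << " bytes" << std::endl;
  return 0;
}
```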
PoolBuffer::PoolBuffer(MemoryPool* pool) : ResizableBuffer(nullptr, 0) { if (pool == nullptr) { pool = default_memory_pool(); diff --git a/cpp/src/arrow/buffer.h b/cpp/src/arrow/buffer.h index 450a4c78b5bbb..d12eeb4df9eed 100644 --- a/cpp/src/arrow/buffer.h +++ b/cpp/src/arrow/buffer.h @@ -25,6 +25,7 @@ #include #include +#include "arrow/memory_pool.h" #include "arrow/status.h" #include "arrow/util/bit-util.h" #include "arrow/util/macros.h" @@ -32,13 +33,12 @@ namespace arrow { -class MemoryPool; - // ---------------------------------------------------------------------- // Buffer classes -/// Immutable API for a chunk of bytes which may or may not be owned by the -/// class instance. +/// \class Buffer +/// \brief Object containing a pointer to a piece of contiguous memory with a +/// particular size. Base class does not own its memory /// /// Buffers have two related notions of length: size and capacity. Size is /// the number of bytes that might have valid data. Capacity is the number @@ -97,6 +97,20 @@ class ARROW_EXPORT Buffer { Status Copy(const int64_t start, const int64_t nbytes, std::shared_ptr* out) const; + /// \brief Construct a new buffer that owns its memory from a std::string + /// + /// \param[in] data a std::string object + /// \param[in] pool a memory pool + /// \param[out] out the created buffer + /// + /// \return Status message + static Status FromString(const std::string& data, MemoryPool* pool, + std::shared_ptr* out); + + /// \brief Construct a new buffer that owns its memory from a std::string + /// using the default memory pool + static Status FromString(const std::string& data, std::shared_ptr* out); + int64_t capacity() const { return capacity_; } const uint8_t* data() const { return data_; } uint8_t* mutable_data() { return mutable_data_; } @@ -133,7 +147,8 @@ ARROW_EXPORT std::shared_ptr SliceMutableBuffer(const std::shared_ptr& buffer, const int64_t offset, const int64_t length); -/// A Buffer whose contents can be mutated. May or may not own its data. +/// \class MutableBuffer +/// \brief A Buffer whose contents can be mutated. May or may not own its data. class ARROW_EXPORT MutableBuffer : public Buffer { public: MutableBuffer(uint8_t* data, const int64_t size) : Buffer(data, size) { @@ -148,6 +163,8 @@ class ARROW_EXPORT MutableBuffer : public Buffer { MutableBuffer() : Buffer(NULLPTR, 0) {} }; +/// \class ResizableBuffer +/// \brief A mutable buffer that can be resized class ARROW_EXPORT ResizableBuffer : public MutableBuffer { public: /// Change buffer reported size to indicated size, allocating memory if @@ -190,13 +207,22 @@ class ARROW_EXPORT PoolBuffer : public ResizableBuffer { MemoryPool* pool_; }; +/// \class BufferBuilder +/// \brief A class for incrementally building a contiguous chunk of in-memory data class ARROW_EXPORT BufferBuilder { public: - explicit BufferBuilder(MemoryPool* pool) + explicit BufferBuilder(MemoryPool* pool ARROW_MEMORY_POOL_DEFAULT) : pool_(pool), data_(NULLPTR), capacity_(0), size_(0) {} - /// Resizes the buffer to the nearest multiple of 64 bytes per Layout.md - Status Resize(const int64_t elements) { + /// \brief Resizes the buffer to the nearest multiple of 64 bytes + /// + /// \param elements the new capacity of the of the builder. Will be rounded + /// up to a multiple of 64 bytes for padding + /// \param shrink_to_fit if new capacity smaller than existing size, + /// reallocate internal buffer. 
Set to false to avoid reallocations when + /// shrinking the builder + /// \return Status + Status Resize(const int64_t elements, bool shrink_to_fit = true) { // Resize(0) is a no-op if (elements == 0) { return Status::OK(); @@ -205,7 +231,7 @@ class ARROW_EXPORT BufferBuilder { buffer_ = std::make_shared(pool_); } int64_t old_capacity = capacity_; - RETURN_NOT_OK(buffer_->Resize(elements)); + RETURN_NOT_OK(buffer_->Resize(elements, shrink_to_fit)); capacity_ = buffer_->capacity(); data_ = buffer_->mutable_data(); if (capacity_ > old_capacity) { @@ -214,7 +240,14 @@ class ARROW_EXPORT BufferBuilder { return Status::OK(); } - Status Append(const uint8_t* data, int64_t length) { + /// \brief Ensure that builder can accommodate the additional number of bytes + /// without the need to perform allocations + /// + /// \param size number of additional bytes to make space for + /// \return Status + Status Reserve(const int64_t size) { return Resize(size_ + size, false); } + + Status Append(const void* data, int64_t length) { if (capacity_ < length + size_) { int64_t new_capacity = BitUtil::NextPower2(length + size_); RETURN_NOT_OK(Resize(new_capacity)); @@ -248,7 +281,7 @@ class ARROW_EXPORT BufferBuilder { } // Unsafe methods don't check existing size - void UnsafeAppend(const uint8_t* data, int64_t length) { + void UnsafeAppend(const void* data, int64_t length) { memcpy(data_ + size_, data, static_cast(length)); size_ += length; } @@ -314,6 +347,7 @@ class ARROW_EXPORT TypedBufferBuilder : public BufferBuilder { const T* data() const { return reinterpret_cast(data_); } int64_t length() const { return size_ / sizeof(T); } + int64_t capacity() const { return capacity_ / sizeof(T); } }; /// \brief Allocate a fixed size mutable buffer from a memory pool diff --git a/cpp/src/arrow/builder.cc b/cpp/src/arrow/builder.cc index de132b5f6a0d1..a740299dfe194 100644 --- a/cpp/src/arrow/builder.cc +++ b/cpp/src/arrow/builder.cc @@ -22,6 +22,7 @@ #include #include #include +#include #include #include "arrow/array.h" @@ -1165,13 +1166,13 @@ Status ListBuilder::Init(int64_t elements) { DCHECK_LT(elements, std::numeric_limits::max()); RETURN_NOT_OK(ArrayBuilder::Init(elements)); // one more then requested for offsets - return offsets_builder_.Resize((elements + 1) * sizeof(int64_t)); + return offsets_builder_.Resize((elements + 1) * sizeof(int32_t)); } Status ListBuilder::Resize(int64_t capacity) { DCHECK_LT(capacity, std::numeric_limits::max()); // one more then requested for offsets - RETURN_NOT_OK(offsets_builder_.Resize((capacity + 1) * sizeof(int64_t))); + RETURN_NOT_OK(offsets_builder_.Resize((capacity + 1) * sizeof(int32_t))); return ArrayBuilder::Resize(capacity); } @@ -1216,16 +1217,26 @@ Status BinaryBuilder::Init(int64_t elements) { DCHECK_LT(elements, std::numeric_limits::max()); RETURN_NOT_OK(ArrayBuilder::Init(elements)); // one more then requested for offsets - return offsets_builder_.Resize((elements + 1) * sizeof(int64_t)); + return offsets_builder_.Resize((elements + 1) * sizeof(int32_t)); } Status BinaryBuilder::Resize(int64_t capacity) { DCHECK_LT(capacity, std::numeric_limits::max()); // one more then requested for offsets - RETURN_NOT_OK(offsets_builder_.Resize((capacity + 1) * sizeof(int64_t))); + RETURN_NOT_OK(offsets_builder_.Resize((capacity + 1) * sizeof(int32_t))); return ArrayBuilder::Resize(capacity); } +Status BinaryBuilder::ReserveData(int64_t elements) { + if (value_data_length() + elements > value_data_capacity()) { + if (value_data_length() + elements > 
std::numeric_limits::max()) { + return Status::Invalid("Cannot reserve capacity larger than 2^31 - 1 for binary"); + } + RETURN_NOT_OK(value_data_builder_.Reserve(elements)); + } + return Status::OK(); +} + Status BinaryBuilder::AppendNextOffset() { const int64_t num_bytes = value_data_builder_.length(); if (ARROW_PREDICT_FALSE(num_bytes > kMaximumCapacity)) { diff --git a/cpp/src/arrow/builder.h b/cpp/src/arrow/builder.h index ce7b8cd197da3..d1611f60cd924 100644 --- a/cpp/src/arrow/builder.h +++ b/cpp/src/arrow/builder.h @@ -682,10 +682,15 @@ class ARROW_EXPORT BinaryBuilder : public ArrayBuilder { Status Init(int64_t elements) override; Status Resize(int64_t capacity) override; + /// \brief Ensures there is enough allocated capacity to append the indicated + /// number of bytes to the value data buffer without additional allocations + Status ReserveData(int64_t elements); Status FinishInternal(std::shared_ptr* out) override; /// \return size of values buffer so far int64_t value_data_length() const { return value_data_builder_.length(); } + /// \return capacity of values buffer + int64_t value_data_capacity() const { return value_data_builder_.capacity(); } /// Temporary access to a value. /// diff --git a/cpp/src/arrow/compute/context.h b/cpp/src/arrow/compute/context.h index 051c91bf049fa..09838195a52ee 100644 --- a/cpp/src/arrow/compute/context.h +++ b/cpp/src/arrow/compute/context.h @@ -18,6 +18,8 @@ #ifndef ARROW_COMPUTE_CONTEXT_H #define ARROW_COMPUTE_CONTEXT_H +#include + #include "arrow/memory_pool.h" #include "arrow/status.h" #include "arrow/type_fwd.h" diff --git a/cpp/src/arrow/compute/kernels/hash.cc b/cpp/src/arrow/compute/kernels/hash.cc index 1face78bdebfb..acbf403987b40 100644 --- a/cpp/src/arrow/compute/kernels/hash.cc +++ b/cpp/src/arrow/compute/kernels/hash.cc @@ -23,6 +23,7 @@ #include #include #include +#include #include #include "arrow/builder.h" @@ -406,12 +407,18 @@ class HashTableKernel> : public HashTable { } Status Append(const ArrayData& arr) override { + constexpr uint8_t empty_value = 0; if (!initialized_) { RETURN_NOT_OK(Init()); } const int32_t* offsets = GetValues(arr, 1); - const uint8_t* data = GetValues(arr, 2); + const uint8_t* data; + if (arr.buffers[2].get() == nullptr) { + data = &empty_value; + } else { + data = GetValues(arr, 2); + } auto action = static_cast(this); RETURN_NOT_OK(action->Reserve(arr.length)); diff --git a/cpp/src/arrow/compute/kernels/util-internal.cc b/cpp/src/arrow/compute/kernels/util-internal.cc index 28428bfcba6c6..0734365859b5a 100644 --- a/cpp/src/arrow/compute/kernels/util-internal.cc +++ b/cpp/src/arrow/compute/kernels/util-internal.cc @@ -17,6 +17,7 @@ #include "arrow/compute/kernels/util-internal.h" +#include #include #include "arrow/array.h" diff --git a/cpp/src/arrow/compute/kernels/util-internal.h b/cpp/src/arrow/compute/kernels/util-internal.h index 7633fed4a8fe7..2f611320a7687 100644 --- a/cpp/src/arrow/compute/kernels/util-internal.h +++ b/cpp/src/arrow/compute/kernels/util-internal.h @@ -18,6 +18,7 @@ #ifndef ARROW_COMPUTE_KERNELS_UTIL_INTERNAL_H #define ARROW_COMPUTE_KERNELS_UTIL_INTERNAL_H +#include #include #include "arrow/compute/kernel.h" diff --git a/cpp/src/arrow/io/hdfs-internal.cc b/cpp/src/arrow/io/hdfs-internal.cc index 9cd1c5052fe8d..efceb8ae6b403 100644 --- a/cpp/src/arrow/io/hdfs-internal.cc +++ b/cpp/src/arrow/io/hdfs-internal.cc @@ -147,7 +147,7 @@ static std::vector get_potential_libjvm_paths() { file_name = "jvm.dll"; #elif __APPLE__ search_prefixes = {""}; - search_suffixes = {"", 
"/jre/lib/server"}; + search_suffixes = {"", "/jre/lib/server", "/lib/server"}; file_name = "libjvm.dylib"; // SFrame uses /usr/libexec/java_home to find JAVA_HOME; for now we are @@ -175,7 +175,7 @@ static std::vector get_potential_libjvm_paths() { "/usr/lib/jvm/default", // alt centos "/usr/java/latest", // alt centos }; - search_suffixes = {"/jre/lib/amd64/server"}; + search_suffixes = {"", "/jre/lib/amd64/server", "/lib/amd64/server"}; file_name = "libjvm.so"; #endif // From direct environment variable @@ -310,6 +310,10 @@ void LibHdfsShim::BuilderSetKerbTicketCachePath(hdfsBuilder* bld, this->hdfsBuilderSetKerbTicketCachePath(bld, kerbTicketCachePath); } +void LibHdfsShim::BuilderSetForceNewInstance(hdfsBuilder* bld) { + this->hdfsBuilderSetForceNewInstance(bld); +} + hdfsFS LibHdfsShim::BuilderConnect(hdfsBuilder* bld) { return this->hdfsBuilderConnect(bld); } @@ -490,6 +494,7 @@ Status LibHdfsShim::GetRequiredSymbols() { GET_SYMBOL_REQUIRED(this, hdfsBuilderSetNameNodePort); GET_SYMBOL_REQUIRED(this, hdfsBuilderSetUserName); GET_SYMBOL_REQUIRED(this, hdfsBuilderSetKerbTicketCachePath); + GET_SYMBOL_REQUIRED(this, hdfsBuilderSetForceNewInstance); GET_SYMBOL_REQUIRED(this, hdfsBuilderConnect); GET_SYMBOL_REQUIRED(this, hdfsCreateDirectory); GET_SYMBOL_REQUIRED(this, hdfsDelete); diff --git a/cpp/src/arrow/io/hdfs-internal.h b/cpp/src/arrow/io/hdfs-internal.h index df925cf62823a..f0fce23c229ab 100644 --- a/cpp/src/arrow/io/hdfs-internal.h +++ b/cpp/src/arrow/io/hdfs-internal.h @@ -51,6 +51,7 @@ struct LibHdfsShim { void (*hdfsBuilderSetUserName)(hdfsBuilder* bld, const char* userName); void (*hdfsBuilderSetKerbTicketCachePath)(hdfsBuilder* bld, const char* kerbTicketCachePath); + void (*hdfsBuilderSetForceNewInstance)(hdfsBuilder* bld); hdfsFS (*hdfsBuilderConnect)(hdfsBuilder* bld); int (*hdfsDisconnect)(hdfsFS fs); @@ -95,6 +96,7 @@ struct LibHdfsShim { this->hdfsBuilderSetNameNodePort = nullptr; this->hdfsBuilderSetUserName = nullptr; this->hdfsBuilderSetKerbTicketCachePath = nullptr; + this->hdfsBuilderSetForceNewInstance = nullptr; this->hdfsBuilderConnect = nullptr; this->hdfsDisconnect = nullptr; this->hdfsOpenFile = nullptr; @@ -138,6 +140,8 @@ struct LibHdfsShim { void BuilderSetKerbTicketCachePath(hdfsBuilder* bld, const char* kerbTicketCachePath); + void BuilderSetForceNewInstance(hdfsBuilder* bld); + hdfsFS BuilderConnect(hdfsBuilder* bld); int Disconnect(hdfsFS fs); diff --git a/cpp/src/arrow/io/hdfs.cc b/cpp/src/arrow/io/hdfs.cc index 6e3e4a7a1c7e7..6c569ae1e2786 100644 --- a/cpp/src/arrow/io/hdfs.cc +++ b/cpp/src/arrow/io/hdfs.cc @@ -335,6 +335,7 @@ class HadoopFileSystem::HadoopFileSystemImpl { if (!config->kerb_ticket.empty()) { driver_->BuilderSetKerbTicketCachePath(builder, config->kerb_ticket.c_str()); } + driver_->BuilderSetForceNewInstance(builder); fs_ = driver_->BuilderConnect(builder); if (fs_ == nullptr) { diff --git a/cpp/src/arrow/io/io-hdfs-test.cc b/cpp/src/arrow/io/io-hdfs-test.cc index 5305b4774624d..f2ded6ff4b945 100644 --- a/cpp/src/arrow/io/io-hdfs-test.cc +++ b/cpp/src/arrow/io/io-hdfs-test.cc @@ -178,6 +178,21 @@ TYPED_TEST(TestHadoopFileSystem, ConnectsAgain) { ASSERT_OK(client->Disconnect()); } +TYPED_TEST(TestHadoopFileSystem, MultipleClients) { + SKIP_IF_NO_DRIVER(); + + std::shared_ptr client1; + std::shared_ptr client2; + ASSERT_OK(HadoopFileSystem::Connect(&this->conf_, &client1)); + ASSERT_OK(HadoopFileSystem::Connect(&this->conf_, &client2)); + ASSERT_OK(client1->Disconnect()); + + // client2 continues to function after equivalent client1 
has shutdown + std::vector listing; + EXPECT_OK(client2->ListDirectory(this->scratch_dir_, &listing)); + ASSERT_OK(client2->Disconnect()); +} + TYPED_TEST(TestHadoopFileSystem, MakeDirectory) { SKIP_IF_NO_DRIVER(); diff --git a/cpp/src/arrow/ipc/feather.cc b/cpp/src/arrow/ipc/feather.cc index d3872503edf19..f440c19efe414 100644 --- a/cpp/src/arrow/ipc/feather.cc +++ b/cpp/src/arrow/ipc/feather.cc @@ -22,6 +22,7 @@ #include #include // IWYU pragma: keep #include +#include #include #include "flatbuffers/flatbuffers.h" diff --git a/cpp/src/arrow/ipc/json-internal.cc b/cpp/src/arrow/ipc/json-internal.cc index 4088a8f20e6a0..204bbc40b0072 100644 --- a/cpp/src/arrow/ipc/json-internal.cc +++ b/cpp/src/arrow/ipc/json-internal.cc @@ -866,7 +866,7 @@ static Status GetField(const rj::Value& obj, const DictionaryMemo* dictionary_me if (dictionary_memo != nullptr && it_dictionary != json_field.MemberEnd()) { // Field is dictionary encoded. We must have already RETURN_NOT_OBJECT("dictionary", it_dictionary, json_field); - int64_t dictionary_id; + int64_t dictionary_id = -1; bool is_ordered; std::shared_ptr index_type; RETURN_NOT_OK(ParseDictionary(it_dictionary->value.GetObject(), &dictionary_id, @@ -1346,7 +1346,7 @@ static Status ReadDictionaries(const rj::Value& doc, const DictionaryTypeMap& id for (const rj::Value& val : dictionary_array) { DCHECK(val.IsObject()); - int64_t dictionary_id; + int64_t dictionary_id = -1; std::shared_ptr dictionary; RETURN_NOT_OK( ReadDictionary(val.GetObject(), id_to_field, pool, &dictionary_id, &dictionary)); diff --git a/cpp/src/arrow/ipc/reader.cc b/cpp/src/arrow/ipc/reader.cc index ae0f8f39806b7..cc3b6e55783e3 100644 --- a/cpp/src/arrow/ipc/reader.cc +++ b/cpp/src/arrow/ipc/reader.cc @@ -22,6 +22,7 @@ #include #include #include +#include #include #include // IWYU pragma: export diff --git a/cpp/src/arrow/pretty_print.cc b/cpp/src/arrow/pretty_print.cc index bd5f8ce10ea68..994f528ea4bad 100644 --- a/cpp/src/arrow/pretty_print.cc +++ b/cpp/src/arrow/pretty_print.cc @@ -15,6 +15,7 @@ // specific language governing permissions and limitations // under the License. 
+#include #include #include #include diff --git a/cpp/src/arrow/python/arrow_to_pandas.cc b/cpp/src/arrow/python/arrow_to_pandas.cc index e21bbda055953..5c8c970e1e058 100644 --- a/cpp/src/arrow/python/arrow_to_pandas.cc +++ b/cpp/src/arrow/python/arrow_to_pandas.cc @@ -963,7 +963,7 @@ class DatetimeTZBlock : public DatetimeBlock { class CategoricalBlock : public PandasBlock { public: explicit CategoricalBlock(PandasOptions options, MemoryPool* pool, int64_t num_rows) - : PandasBlock(options, num_rows, 1), pool_(pool) {} + : PandasBlock(options, num_rows, 1), pool_(pool), needs_copy_(false) {} Status Allocate() override { return Status::NotImplemented( @@ -996,14 +996,20 @@ class CategoricalBlock : public PandasBlock { return Status::OK(); }; - if (data.num_chunks() == 1 && indices_first.null_count() == 0) { + if (!needs_copy_ && data.num_chunks() == 1 && indices_first.null_count() == 0) { RETURN_NOT_OK(CheckIndices(indices_first, dict_arr_first.dictionary()->length())); RETURN_NOT_OK(AllocateNDArrayFromIndices(npy_type, indices_first)); } else { if (options_.zero_copy_only) { std::stringstream ss; - ss << "Needed to copy " << data.num_chunks() << " chunks with " - << indices_first.null_count() << " indices nulls, but zero_copy_only was True"; + if (needs_copy_) { + ss << "Need to allocate categorical memory, " + << "but only zero-copy conversions allowed."; + } else { + ss << "Needed to copy " << data.num_chunks() << " chunks with " + << indices_first.null_count() + << " indices nulls, but zero_copy_only was True"; + } return Status::Invalid(ss.str()); } RETURN_NOT_OK(AllocateNDArray(npy_type, 1)); @@ -1034,6 +1040,7 @@ class CategoricalBlock : public PandasBlock { std::shared_ptr converted_col; if (options_.strings_to_categorical && (col->type()->id() == Type::STRING || col->type()->id() == Type::BINARY)) { + needs_copy_ = true; compute::FunctionContext ctx(pool_); Datum out; @@ -1135,6 +1142,7 @@ class CategoricalBlock : public PandasBlock { MemoryPool* pool_; OwnedRef dictionary_; bool ordered_; + bool needs_copy_; }; Status MakeBlock(PandasOptions options, PandasBlock::type type, int64_t num_rows, diff --git a/cpp/src/arrow/python/arrow_to_python.cc b/cpp/src/arrow/python/arrow_to_python.cc index c060ab8bfd6db..c67e5410eb6ee 100644 --- a/cpp/src/arrow/python/arrow_to_python.cc +++ b/cpp/src/arrow/python/arrow_to_python.cc @@ -22,6 +22,7 @@ #include #include #include +#include #include #include diff --git a/cpp/src/arrow/python/builtin_convert.cc b/cpp/src/arrow/python/builtin_convert.cc index cd88d557d4830..1b3c101758eec 100644 --- a/cpp/src/arrow/python/builtin_convert.cc +++ b/cpp/src/arrow/python/builtin_convert.cc @@ -23,6 +23,8 @@ #include #include #include +#include +#include #include "arrow/python/builtin_convert.h" @@ -32,6 +34,7 @@ #include "arrow/util/logging.h" #include "arrow/python/helpers.h" +#include "arrow/python/numpy_convert.h" #include "arrow/python/util/datetime.h" namespace arrow { @@ -93,6 +96,21 @@ class ScalarVisitor { ++binary_count_; } else if (PyUnicode_Check(obj)) { ++unicode_count_; + } else if (PyArray_CheckAnyScalarExact(obj)) { + std::shared_ptr type; + RETURN_NOT_OK(NumPyDtypeToArrow(PyArray_DescrFromScalar(obj), &type)); + if (is_integer(type->id())) { + ++int_count_; + } else if (is_floating(type->id())) { + ++float_count_; + } else if (type->id() == Type::TIMESTAMP) { + ++timestamp_count_; + } else { + std::ostringstream ss; + ss << "Found a NumPy scalar with Arrow dtype that we cannot handle: "; + ss << type->ToString(); + return 
Status::Invalid(ss.str()); + } } else { // TODO(wesm): accumulate error information somewhere static std::string supported_types = @@ -156,38 +174,26 @@ class SeqVisitor { Status Visit(PyObject* obj, int level = 0) { max_nesting_level_ = std::max(max_nesting_level_, level); - // Loop through either a sequence or an iterator. - if (PySequence_Check(obj)) { - Py_ssize_t size = PySequence_Size(obj); - for (int64_t i = 0; i < size; ++i) { - OwnedRef ref; - if (PyArray_Check(obj)) { - auto array = reinterpret_cast(obj); - auto ptr = reinterpret_cast(PyArray_GETPTR1(array, i)); - - ref.reset(PyArray_GETITEM(array, ptr)); - RETURN_IF_PYERROR(); - - RETURN_NOT_OK(VisitElem(ref, level)); - } else { - ref.reset(PySequence_GetItem(obj, i)); - RETURN_IF_PYERROR(); - RETURN_NOT_OK(VisitElem(ref, level)); - } - } - } else if (PyObject_HasAttrString(obj, "__iter__")) { - OwnedRef iter(PyObject_GetIter(obj)); - RETURN_IF_PYERROR(); + // Loop through a sequence + if (!PySequence_Check(obj)) + return Status::TypeError("Object is not a sequence or iterable"); - PyObject* item = NULLPTR; - while ((item = PyIter_Next(iter.obj()))) { + Py_ssize_t size = PySequence_Size(obj); + for (int64_t i = 0; i < size; ++i) { + OwnedRef ref; + if (PyArray_Check(obj)) { + auto array = reinterpret_cast(obj); + auto ptr = reinterpret_cast(PyArray_GETPTR1(array, i)); + + ref.reset(PyArray_GETITEM(array, ptr)); RETURN_IF_PYERROR(); - OwnedRef ref(item); + RETURN_NOT_OK(VisitElem(ref, level)); + } else { + ref.reset(PySequence_GetItem(obj, i)); + RETURN_IF_PYERROR(); RETURN_NOT_OK(VisitElem(ref, level)); } - } else { - return Status::TypeError("Object is not a sequence or iterable"); } return Status::OK(); } @@ -269,25 +275,45 @@ class SeqVisitor { } }; -Status InferArrowSize(PyObject* obj, int64_t* size) { +// Convert *obj* to a sequence if necessary +// Fill *size* to its length. If >= 0 on entry, *size* is an upper size +// bound that may lead to truncation. 
+Status ConvertToSequenceAndInferSize(PyObject* obj, PyObject** seq, int64_t* size) { if (PySequence_Check(obj)) { - *size = static_cast(PySequence_Size(obj)); - } else if (PyObject_HasAttrString(obj, "__iter__")) { + // obj is already a sequence + int64_t real_size = static_cast(PySequence_Size(obj)); + if (*size < 0) { + *size = real_size; + } else { + *size = std::min(real_size, *size); + } + Py_INCREF(obj); + *seq = obj; + } else if (*size < 0) { + // unknown size, exhaust iterator + *seq = PySequence_List(obj); + RETURN_IF_PYERROR(); + *size = static_cast(PyList_GET_SIZE(*seq)); + } else { + // size is known but iterator could be infinite + Py_ssize_t i, n = *size; PyObject* iter = PyObject_GetIter(obj); + RETURN_IF_PYERROR(); OwnedRef iter_ref(iter); - *size = 0; - PyObject* item; - while ((item = PyIter_Next(iter))) { - OwnedRef item_ref(item); - *size += 1; + PyObject* lst = PyList_New(n); + RETURN_IF_PYERROR(); + for (i = 0; i < n; i++) { + PyObject* item = PyIter_Next(iter); + if (!item) break; + PyList_SET_ITEM(lst, i, item); } - } else { - return Status::TypeError("Object is not a sequence or iterable"); - } - if (PyErr_Occurred()) { - // Not a sequence - PyErr_Clear(); - return Status::TypeError("Object is not a sequence or iterable"); + // Shrink list if len(iterator) < size + if (i < n && PyList_SetSlice(lst, i, n, NULL)) { + Py_DECREF(lst); + return Status::UnknownError("failed to resize list"); + } + *seq = lst; + *size = std::min(i, *size); } return Status::OK(); } @@ -309,7 +335,10 @@ Status InferArrowType(PyObject* obj, std::shared_ptr* out_type) { Status InferArrowTypeAndSize(PyObject* obj, int64_t* size, std::shared_ptr* out_type) { - RETURN_NOT_OK(InferArrowSize(obj, size)); + if (!PySequence_Check(obj)) { + return Status::TypeError("Object is not a sequence"); + } + *size = static_cast(PySequence_Size(obj)); // For 0-length sequences, refuse to guess if (*size == 0) { @@ -329,7 +358,11 @@ class SeqConverter { return Status::OK(); } - virtual Status AppendData(PyObject* seq, int64_t size) = 0; + // Append a single (non-sequence) Python datum to the underlying builder + virtual Status AppendSingle(PyObject* obj) = 0; + + // Append the contents of a Python sequence to the underlying builder + virtual Status AppendMultiple(PyObject* seq, int64_t size) = 0; virtual ~SeqConverter() = default; @@ -350,66 +383,57 @@ class TypedConverter : public SeqConverter { BuilderType* typed_builder_; }; +// We use the CRTP trick here to devirtualize the AppendItem() and AppendNull() +// method calls. 
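
The CRTP comment above refers to the `TypedConverterVisitor` template that follows. As a self-contained illustration of that devirtualization trick, here is a minimal sketch with hypothetical names (`VisitorBase`, `PrintVisitor` are not from this patch): the base class casts `this` to the derived type, so the per-element `AppendItem()` call is resolved at compile time instead of through a virtual table.

```cpp
// Minimal CRTP sketch: the base template calls into the derived class
// statically, avoiding virtual dispatch for each appended element.
#include <iostream>

template <typename Derived>
class VisitorBase {
 public:
  void VisitAll(int n) {
    for (int i = 0; i < n; ++i) {
      // Statically resolved call into the derived class.
      static_cast<Derived*>(this)->AppendItem(i);
    }
  }
};

class PrintVisitor : public VisitorBase<PrintVisitor> {
 public:
  void AppendItem(int value) { std::cout << value << "\n"; }
};

int main() {
  PrintVisitor v;
  v.VisitAll(3);  // prints 0, 1, 2 with no virtual calls
  return 0;
}
```
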
template class TypedConverterVisitor : public TypedConverter { public: - Status AppendData(PyObject* obj, int64_t size) override { + Status AppendSingle(PyObject* obj) override { + if (obj == Py_None) { + return static_cast(this)->AppendNull(); + } else { + return static_cast(this)->AppendItem(obj); + } + } + + Status AppendMultiple(PyObject* obj, int64_t size) override { /// Ensure we've allocated enough space RETURN_NOT_OK(this->typed_builder_->Reserve(size)); // Iterate over the items adding each one if (PySequence_Check(obj)) { for (int64_t i = 0; i < size; ++i) { OwnedRef ref(PySequence_GetItem(obj, i)); - if (ref.obj() == Py_None) { - RETURN_NOT_OK(this->typed_builder_->AppendNull()); - } else { - RETURN_NOT_OK(static_cast(this)->AppendItem(ref)); - } - } - } else if (PyObject_HasAttrString(obj, "__iter__")) { - PyObject* iter = PyObject_GetIter(obj); - OwnedRef iter_ref(iter); - PyObject* item; - int64_t i = 0; - // To allow people with long generators to only convert a subset, stop - // consuming at size. - while ((item = PyIter_Next(iter)) && i < size) { - OwnedRef ref(item); - if (ref.obj() == Py_None) { - RETURN_NOT_OK(this->typed_builder_->AppendNull()); - } else { - RETURN_NOT_OK(static_cast(this)->AppendItem(ref)); - } - ++i; - } - if (size != i) { - RETURN_NOT_OK(this->typed_builder_->Resize(i)); + RETURN_NOT_OK(static_cast(this)->AppendSingle(ref.obj())); } } else { - return Status::TypeError("Object is not a sequence or iterable"); + return Status::TypeError("Object is not a sequence"); } return Status::OK(); } + + // Append a missing item (default implementation) + Status AppendNull() { return this->typed_builder_->AppendNull(); } }; class NullConverter : public TypedConverterVisitor { public: - Status AppendItem(const OwnedRef& item) { + // Append a non-missing item + Status AppendItem(PyObject* obj) { return Status::Invalid("NullConverter: passed non-None value"); } }; class BoolConverter : public TypedConverterVisitor { public: - Status AppendItem(const OwnedRef& item) { - return typed_builder_->Append(item.obj() == Py_True); - } + // Append a non-missing item + Status AppendItem(PyObject* obj) { return typed_builder_->Append(obj == Py_True); } }; class Int8Converter : public TypedConverterVisitor { public: - Status AppendItem(const OwnedRef& item) { - const auto val = static_cast(PyLong_AsLongLong(item.obj())); + // Append a non-missing item + Status AppendItem(PyObject* obj) { + const auto val = static_cast(PyLong_AsLongLong(obj)); if (ARROW_PREDICT_FALSE(val > std::numeric_limits::max() || val < std::numeric_limits::min())) { @@ -424,8 +448,9 @@ class Int8Converter : public TypedConverterVisitor { class Int16Converter : public TypedConverterVisitor { public: - Status AppendItem(const OwnedRef& item) { - const auto val = static_cast(PyLong_AsLongLong(item.obj())); + // Append a non-missing item + Status AppendItem(PyObject* obj) { + const auto val = static_cast(PyLong_AsLongLong(obj)); if (ARROW_PREDICT_FALSE(val > std::numeric_limits::max() || val < std::numeric_limits::min())) { @@ -440,8 +465,9 @@ class Int16Converter : public TypedConverterVisitor { public: - Status AppendItem(const OwnedRef& item) { - const auto val = static_cast(PyLong_AsLongLong(item.obj())); + // Append a non-missing item + Status AppendItem(PyObject* obj) { + const auto val = static_cast(PyLong_AsLongLong(obj)); if (ARROW_PREDICT_FALSE(val > std::numeric_limits::max() || val < std::numeric_limits::min())) { @@ -456,8 +482,9 @@ class Int32Converter : public TypedConverterVisitor { public: - 
Status AppendItem(const OwnedRef& item) { - const auto val = static_cast(PyLong_AsLongLong(item.obj())); + // Append a non-missing item + Status AppendItem(PyObject* obj) { + const auto val = static_cast(PyLong_AsLongLong(obj)); RETURN_IF_PYERROR(); return typed_builder_->Append(val); } @@ -465,8 +492,9 @@ class Int64Converter : public TypedConverterVisitor { public: - Status AppendItem(const OwnedRef& item) { - const auto val = static_cast(PyLong_AsLongLong(item.obj())); + // Append a non-missing item + Status AppendItem(PyObject* obj) { + const auto val = static_cast(PyLong_AsLongLong(obj)); RETURN_IF_PYERROR(); if (ARROW_PREDICT_FALSE(val > std::numeric_limits::max())) { @@ -480,8 +508,9 @@ class UInt8Converter : public TypedConverterVisitor { public: - Status AppendItem(const OwnedRef& item) { - const auto val = static_cast(PyLong_AsLongLong(item.obj())); + // Append a non-missing item + Status AppendItem(PyObject* obj) { + const auto val = static_cast(PyLong_AsLongLong(obj)); RETURN_IF_PYERROR(); if (ARROW_PREDICT_FALSE(val > std::numeric_limits::max())) { @@ -495,8 +524,9 @@ class UInt16Converter : public TypedConverterVisitor { public: - Status AppendItem(const OwnedRef& item) { - const auto val = static_cast(PyLong_AsLongLong(item.obj())); + // Append a non-missing item + Status AppendItem(PyObject* obj) { + const auto val = static_cast(PyLong_AsLongLong(obj)); RETURN_IF_PYERROR(); if (ARROW_PREDICT_FALSE(val > std::numeric_limits::max())) { @@ -510,8 +540,9 @@ class UInt32Converter : public TypedConverterVisitor { public: - Status AppendItem(const OwnedRef& item) { - const auto val = static_cast(PyLong_AsLongLong(item.obj())); + // Append a non-missing item + Status AppendItem(PyObject* obj) { + const auto val = static_cast(PyLong_AsUnsignedLongLong(obj)); RETURN_IF_PYERROR(); return typed_builder_->Append(val); } @@ -519,13 +550,14 @@ class UInt64Converter : public TypedConverterVisitor { public: - Status AppendItem(const OwnedRef& item) { + // Append a non-missing item + Status AppendItem(PyObject* obj) { int32_t t; - if (PyDate_Check(item.obj())) { - auto pydate = reinterpret_cast(item.obj()); + if (PyDate_Check(obj)) { + auto pydate = reinterpret_cast(obj); t = static_cast(PyDate_to_s(pydate)); } else { - const auto casted_val = static_cast(PyLong_AsLongLong(item.obj())); + const auto casted_val = static_cast(PyLong_AsLongLong(obj)); RETURN_IF_PYERROR(); if (casted_val > std::numeric_limits::max()) { return Status::Invalid("Integer as date32 larger than INT32_MAX"); @@ -538,13 +570,14 @@ class Date32Converter : public TypedConverterVisitor { public: - Status AppendItem(const OwnedRef& item) { + // Append a non-missing item + Status AppendItem(PyObject* obj) { int64_t t; - if (PyDate_Check(item.obj())) { - auto pydate = reinterpret_cast(item.obj()); + if (PyDate_Check(obj)) { + auto pydate = reinterpret_cast(obj); t = PyDate_to_ms(pydate); } else { - t = static_cast(PyLong_AsLongLong(item.obj())); + t = static_cast(PyLong_AsLongLong(obj)); RETURN_IF_PYERROR(); } return typed_builder_->Append(t); @@ -556,10 +589,11 @@ class TimestampConverter public: explicit TimestampConverter(TimeUnit::type unit) : unit_(unit) {} - Status AppendItem(const OwnedRef& item) { + // Append a non-missing item + Status AppendItem(PyObject* obj) { int64_t t; - if (PyDateTime_Check(item.obj())) { - auto pydatetime = reinterpret_cast(item.obj()); + if (PyDateTime_Check(obj)) { + auto pydatetime = reinterpret_cast(obj); switch (unit_) { case TimeUnit::SECOND: @@ -574,9 +608,28 @@ class TimestampConverter 
case TimeUnit::NANO: t = PyDateTime_to_ns(pydatetime); break; + default: + return Status::UnknownError("Invalid time unit"); } + } else if (PyArray_CheckAnyScalarExact(obj)) { + // numpy.datetime64 + std::shared_ptr type; + RETURN_NOT_OK(NumPyDtypeToArrow(PyArray_DescrFromScalar(obj), &type)); + if (type->id() != Type::TIMESTAMP) { + std::ostringstream ss; + ss << "Expected np.datetime64 but got: "; + ss << type->ToString(); + return Status::Invalid(ss.str()); + } + const TimestampType& ttype = static_cast(*type); + if (unit_ != ttype.unit()) { + return Status::NotImplemented( + "Cannot convert NumPy datetime64 objects with differing unit"); + } + + t = reinterpret_cast(obj)->obval; } else { - t = static_cast(PyLong_AsLongLong(item.obj())); + t = static_cast(PyLong_AsLongLong(obj)); RETURN_IF_PYERROR(); } return typed_builder_->Append(t); @@ -586,10 +639,21 @@ class TimestampConverter TimeUnit::type unit_; }; +class Float32Converter : public TypedConverterVisitor { + public: + // Append a non-missing item + Status AppendItem(PyObject* obj) { + float val = static_cast(PyFloat_AsDouble(obj)); + RETURN_IF_PYERROR(); + return typed_builder_->Append(val); + } +}; + class DoubleConverter : public TypedConverterVisitor { public: - Status AppendItem(const OwnedRef& item) { - double val = PyFloat_AsDouble(item.obj()); + // Append a non-missing item + Status AppendItem(PyObject* obj) { + double val = PyFloat_AsDouble(obj); RETURN_IF_PYERROR(); return typed_builder_->Append(val); } @@ -597,22 +661,23 @@ class DoubleConverter : public TypedConverterVisitor { public: - Status AppendItem(const OwnedRef& item) { + // Append a non-missing item + Status AppendItem(PyObject* obj) { PyObject* bytes_obj; const char* bytes; Py_ssize_t length; OwnedRef tmp; - if (PyUnicode_Check(item.obj())) { - tmp.reset(PyUnicode_AsUTF8String(item.obj())); + if (PyUnicode_Check(obj)) { + tmp.reset(PyUnicode_AsUTF8String(obj)); RETURN_IF_PYERROR(); bytes_obj = tmp.obj(); - } else if (PyBytes_Check(item.obj())) { - bytes_obj = item.obj(); + } else if (PyBytes_Check(obj)) { + bytes_obj = obj; } else { std::stringstream ss; ss << "Error converting to Binary type: "; - RETURN_NOT_OK(InvalidConversion(item.obj(), "bytes", &ss)); + RETURN_NOT_OK(InvalidConversion(obj, "bytes", &ss)); return Status::Invalid(ss.str()); } // No error checking @@ -625,22 +690,23 @@ class BytesConverter : public TypedConverterVisitor { public: - Status AppendItem(const OwnedRef& item) { + // Append a non-missing item + Status AppendItem(PyObject* obj) { PyObject* bytes_obj; OwnedRef tmp; Py_ssize_t expected_length = std::dynamic_pointer_cast(typed_builder_->type()) ->byte_width(); - if (PyUnicode_Check(item.obj())) { - tmp.reset(PyUnicode_AsUTF8String(item.obj())); + if (PyUnicode_Check(obj)) { + tmp.reset(PyUnicode_AsUTF8String(obj)); RETURN_IF_PYERROR(); bytes_obj = tmp.obj(); - } else if (PyBytes_Check(item.obj())) { - bytes_obj = item.obj(); + } else if (PyBytes_Check(obj)) { + bytes_obj = obj; } else { std::stringstream ss; ss << "Error converting to FixedSizeBinary type: "; - RETURN_NOT_OK(InvalidConversion(item.obj(), "bytes", &ss)); + RETURN_NOT_OK(InvalidConversion(obj, "bytes", &ss)); return Status::Invalid(ss.str()); } // No error checking @@ -652,13 +718,13 @@ class FixedWidthBytesConverter class UTF8Converter : public TypedConverterVisitor { public: - Status AppendItem(const OwnedRef& item) { + // Append a non-missing item + Status AppendItem(PyObject* obj) { PyObject* bytes_obj; OwnedRef tmp; const char* bytes; Py_ssize_t length; - 
PyObject* obj = item.obj(); if (PyBytes_Check(obj)) { tmp.reset( PyUnicode_FromStringAndSize(PyBytes_AS_STRING(obj), PyBytes_GET_SIZE(obj))); @@ -687,73 +753,114 @@ class ListConverter : public TypedConverterVisitor { public: Status Init(ArrayBuilder* builder) override; - Status AppendItem(const OwnedRef& item) { + // Append a non-missing item + Status AppendItem(PyObject* obj) { RETURN_NOT_OK(typed_builder_->Append()); - PyObject* item_obj = item.obj(); - const auto list_size = static_cast(PySequence_Size(item_obj)); - return value_converter_->AppendData(item_obj, list_size); + const auto list_size = static_cast(PySequence_Size(obj)); + return value_converter_->AppendMultiple(obj, list_size); } protected: - std::shared_ptr value_converter_; + std::unique_ptr value_converter_; +}; + +class StructConverter : public TypedConverterVisitor { + public: + Status Init(ArrayBuilder* builder) override; + + // Append a non-missing item + Status AppendItem(PyObject* obj) { + RETURN_NOT_OK(typed_builder_->Append()); + if (!PyDict_Check(obj)) { + return Status::TypeError("dict value expected for struct type"); + } + // NOTE we're ignoring any extraneous dict items + for (int i = 0; i < num_fields_; i++) { + PyObject* nameobj = PyList_GET_ITEM(field_name_list_.obj(), i); + PyObject* valueobj = PyDict_GetItem(obj, nameobj); // borrowed + RETURN_IF_PYERROR(); + RETURN_NOT_OK(value_converters_[i]->AppendSingle(valueobj ? valueobj : Py_None)); + } + + return Status::OK(); + } + + // Append a missing item + Status AppendNull() { + RETURN_NOT_OK(typed_builder_->AppendNull()); + // Need to also insert a missing item on all child builders + // (compare with ListConverter) + for (int i = 0; i < num_fields_; i++) { + RETURN_NOT_OK(value_converters_[i]->AppendSingle(Py_None)); + } + return Status::OK(); + } + + protected: + std::vector> value_converters_; + OwnedRef field_name_list_; + int num_fields_; }; class DecimalConverter : public TypedConverterVisitor { public: - Status AppendItem(const OwnedRef& item) { + // Append a non-missing item + Status AppendItem(PyObject* obj) { /// TODO(phillipc): Check for nan? 
Decimal128 value; const auto& type = static_cast(*typed_builder_->type()); - RETURN_NOT_OK(internal::DecimalFromPythonDecimal(item.obj(), type, &value)); + RETURN_NOT_OK(internal::DecimalFromPythonDecimal(obj, type, &value)); return typed_builder_->Append(value); } }; // Dynamic constructor for sequence converters -std::shared_ptr GetConverter(const std::shared_ptr& type) { +std::unique_ptr GetConverter(const std::shared_ptr& type) { switch (type->id()) { case Type::NA: - return std::make_shared(); + return std::unique_ptr(new NullConverter); case Type::BOOL: - return std::make_shared(); + return std::unique_ptr(new BoolConverter); case Type::INT8: - return std::make_shared(); + return std::unique_ptr(new Int8Converter); case Type::INT16: - return std::make_shared(); + return std::unique_ptr(new Int16Converter); case Type::INT32: - return std::make_shared(); + return std::unique_ptr(new Int32Converter); case Type::INT64: - return std::make_shared(); + return std::unique_ptr(new Int64Converter); case Type::UINT8: - return std::make_shared(); + return std::unique_ptr(new UInt8Converter); case Type::UINT16: - return std::make_shared(); + return std::unique_ptr(new UInt16Converter); case Type::UINT32: - return std::make_shared(); + return std::unique_ptr(new UInt32Converter); case Type::UINT64: - return std::make_shared(); + return std::unique_ptr(new UInt64Converter); case Type::DATE32: - return std::make_shared(); + return std::unique_ptr(new Date32Converter); case Type::DATE64: - return std::make_shared(); + return std::unique_ptr(new Date64Converter); case Type::TIMESTAMP: - return std::make_shared( - static_cast(*type).unit()); + return std::unique_ptr( + new TimestampConverter(static_cast(*type).unit())); + case Type::FLOAT: + return std::unique_ptr(new Float32Converter); case Type::DOUBLE: - return std::make_shared(); + return std::unique_ptr(new DoubleConverter); case Type::BINARY: - return std::make_shared(); + return std::unique_ptr(new BytesConverter); case Type::FIXED_SIZE_BINARY: - return std::make_shared(); + return std::unique_ptr(new FixedWidthBytesConverter); case Type::STRING: - return std::make_shared(); + return std::unique_ptr(new UTF8Converter); case Type::LIST: - return std::make_shared(); - case Type::DECIMAL: { - return std::make_shared(); - } + return std::unique_ptr(new ListConverter); case Type::STRUCT: + return std::unique_ptr(new StructConverter); + case Type::DECIMAL: + return std::unique_ptr(new DecimalConverter); default: return nullptr; } @@ -772,51 +879,102 @@ Status ListConverter::Init(ArrayBuilder* builder) { return value_converter_->Init(typed_builder_->value_builder()); } +Status StructConverter::Init(ArrayBuilder* builder) { + builder_ = builder; + typed_builder_ = static_cast(builder); + StructType* struct_type = static_cast(builder->type().get()); + + num_fields_ = typed_builder_->num_fields(); + DCHECK_EQ(num_fields_, struct_type->num_children()); + + field_name_list_.reset(PyList_New(num_fields_)); + RETURN_IF_PYERROR(); + + // Initialize the child converters and field names + for (int i = 0; i < num_fields_; i++) { + const std::string& field_name(struct_type->child(i)->name()); + std::shared_ptr field_type(struct_type->child(i)->type()); + + auto value_converter = GetConverter(field_type); + if (value_converter == nullptr) { + return Status::NotImplemented("value type not implemented"); + } + RETURN_NOT_OK(value_converter->Init(typed_builder_->field_builder(i))); + value_converters_.push_back(std::move(value_converter)); + + // Store the field name 
as a PyObject, for dict matching + PyObject* nameobj = + PyUnicode_FromStringAndSize(field_name.c_str(), field_name.size()); + RETURN_IF_PYERROR(); + PyList_SET_ITEM(field_name_list_.obj(), i, nameobj); + } + + return Status::OK(); +} + Status AppendPySequence(PyObject* obj, int64_t size, const std::shared_ptr& type, ArrayBuilder* builder) { PyDateTime_IMPORT; - std::shared_ptr converter = GetConverter(type); + auto converter = GetConverter(type); if (converter == nullptr) { std::stringstream ss; ss << "No type converter implemented for " << type->ToString(); return Status::NotImplemented(ss.str()); } RETURN_NOT_OK(converter->Init(builder)); - return converter->AppendData(obj, size); + return converter->AppendMultiple(obj, size); } -Status ConvertPySequence(PyObject* obj, MemoryPool* pool, std::shared_ptr* out) { +static Status ConvertPySequenceReal(PyObject* obj, int64_t size, + const std::shared_ptr* type, + MemoryPool* pool, std::shared_ptr* out) { PyAcquireGIL lock; - std::shared_ptr type; - int64_t size; - RETURN_NOT_OK(InferArrowTypeAndSize(obj, &size, &type)); - return ConvertPySequence(obj, pool, out, type, size); -} -Status ConvertPySequence(PyObject* obj, MemoryPool* pool, std::shared_ptr* out, - const std::shared_ptr& type, int64_t size) { - PyAcquireGIL lock; + PyObject* seq; + ScopedRef tmp_seq_nanny; + + std::shared_ptr real_type; + + RETURN_NOT_OK(ConvertToSequenceAndInferSize(obj, &seq, &size)); + tmp_seq_nanny.reset(seq); + if (type == nullptr) { + RETURN_NOT_OK(InferArrowType(seq, &real_type)); + } else { + real_type = *type; + } + DCHECK_GE(size, 0); + // Handle NA / NullType case - if (type->id() == Type::NA) { + if (real_type->id() == Type::NA) { out->reset(new NullArray(size)); return Status::OK(); } // Give the sequence converter an array builder std::unique_ptr builder; - RETURN_NOT_OK(MakeBuilder(pool, type, &builder)); - RETURN_NOT_OK(AppendPySequence(obj, size, type, builder.get())); + RETURN_NOT_OK(MakeBuilder(pool, real_type, &builder)); + RETURN_NOT_OK(AppendPySequence(seq, size, real_type, builder.get())); return builder->Finish(out); } -Status ConvertPySequence(PyObject* obj, MemoryPool* pool, std::shared_ptr* out, - const std::shared_ptr& type) { - int64_t size; - { - PyAcquireGIL lock; - RETURN_NOT_OK(InferArrowSize(obj, &size)); - } - return ConvertPySequence(obj, pool, out, type, size); +Status ConvertPySequence(PyObject* obj, MemoryPool* pool, std::shared_ptr* out) { + return ConvertPySequenceReal(obj, -1, nullptr, pool, out); +} + +Status ConvertPySequence(PyObject* obj, const std::shared_ptr& type, + MemoryPool* pool, std::shared_ptr* out) { + return ConvertPySequenceReal(obj, -1, &type, pool, out); +} + +Status ConvertPySequence(PyObject* obj, int64_t size, MemoryPool* pool, + std::shared_ptr* out) { + return ConvertPySequenceReal(obj, size, nullptr, pool, out); +} + +Status ConvertPySequence(PyObject* obj, int64_t size, + const std::shared_ptr& type, MemoryPool* pool, + std::shared_ptr* out) { + return ConvertPySequenceReal(obj, size, &type, pool, out); } Status CheckPythonBytesAreFixedLength(PyObject* obj, Py_ssize_t expected_length) { diff --git a/cpp/src/arrow/python/builtin_convert.h b/cpp/src/arrow/python/builtin_convert.h index cde7a1bd4cfdc..4bd3f08edf162 100644 --- a/cpp/src/arrow/python/builtin_convert.h +++ b/cpp/src/arrow/python/builtin_convert.h @@ -39,11 +39,11 @@ class Status; namespace py { +// These three functions take a sequence input, not arbitrary iterables ARROW_EXPORT arrow::Status InferArrowType(PyObject* obj, std::shared_ptr* 
out_type); ARROW_EXPORT arrow::Status InferArrowTypeAndSize( PyObject* obj, int64_t* size, std::shared_ptr* out_type); -ARROW_EXPORT arrow::Status InferArrowSize(PyObject* obj, int64_t* size); ARROW_EXPORT arrow::Status AppendPySequence(PyObject* obj, int64_t size, const std::shared_ptr& type, @@ -53,15 +53,21 @@ ARROW_EXPORT arrow::Status AppendPySequence(PyObject* obj, int64_t size, ARROW_EXPORT Status ConvertPySequence(PyObject* obj, MemoryPool* pool, std::shared_ptr* out); -// Size inference +// Type inference only ARROW_EXPORT -Status ConvertPySequence(PyObject* obj, MemoryPool* pool, std::shared_ptr* out, - const std::shared_ptr& type); +Status ConvertPySequence(PyObject* obj, int64_t size, MemoryPool* pool, + std::shared_ptr* out); + +// Size inference only +ARROW_EXPORT +Status ConvertPySequence(PyObject* obj, const std::shared_ptr& type, + MemoryPool* pool, std::shared_ptr* out); // No inference ARROW_EXPORT -Status ConvertPySequence(PyObject* obj, MemoryPool* pool, std::shared_ptr* out, - const std::shared_ptr& type, int64_t size); +Status ConvertPySequence(PyObject* obj, int64_t size, + const std::shared_ptr& type, MemoryPool* pool, + std::shared_ptr* out); ARROW_EXPORT Status InvalidConversion(PyObject* obj, const std::string& expected_type_name, diff --git a/cpp/src/arrow/python/io.cc b/cpp/src/arrow/python/io.cc index cc3892928c455..2cff046085e69 100644 --- a/cpp/src/arrow/python/io.cc +++ b/cpp/src/arrow/python/io.cc @@ -19,12 +19,14 @@ #include #include +#include #include #include #include "arrow/io/memory.h" #include "arrow/memory_pool.h" #include "arrow/status.h" +#include "arrow/util/logging.h" #include "arrow/python/common.h" @@ -132,12 +134,14 @@ Status PyReadableFile::Tell(int64_t* position) const { Status PyReadableFile::Read(int64_t nbytes, int64_t* bytes_read, void* out) { PyAcquireGIL lock; - PyObject* bytes_obj; + + PyObject* bytes_obj = NULL; ARROW_RETURN_NOT_OK(file_->Read(nbytes, &bytes_obj)); + DCHECK(bytes_obj != NULL); *bytes_read = PyBytes_GET_SIZE(bytes_obj); std::memcpy(out, PyBytes_AS_STRING(bytes_obj), *bytes_read); - Py_DECREF(bytes_obj); + Py_XDECREF(bytes_obj); return Status::OK(); } @@ -145,11 +149,12 @@ Status PyReadableFile::Read(int64_t nbytes, int64_t* bytes_read, void* out) { Status PyReadableFile::Read(int64_t nbytes, std::shared_ptr* out) { PyAcquireGIL lock; - PyObject* bytes_obj; + PyObject* bytes_obj = NULL; ARROW_RETURN_NOT_OK(file_->Read(nbytes, &bytes_obj)); + DCHECK(bytes_obj != NULL); *out = std::make_shared(bytes_obj); - Py_DECREF(bytes_obj); + Py_XDECREF(bytes_obj); return Status::OK(); } @@ -171,13 +176,13 @@ Status PyReadableFile::ReadAt(int64_t position, int64_t nbytes, Status PyReadableFile::GetSize(int64_t* size) { PyAcquireGIL lock; - int64_t current_position; + int64_t current_position = -1; ARROW_RETURN_NOT_OK(file_->Tell(¤t_position)); ARROW_RETURN_NOT_OK(file_->Seek(0, 2)); - int64_t file_size; + int64_t file_size = -1; ARROW_RETURN_NOT_OK(file_->Tell(&file_size)); // Restore previous file position diff --git a/cpp/src/arrow/python/io.h b/cpp/src/arrow/python/io.h index f550de7b2848c..0632d28faf789 100644 --- a/cpp/src/arrow/python/io.h +++ b/cpp/src/arrow/python/io.h @@ -18,6 +18,8 @@ #ifndef PYARROW_IO_H #define PYARROW_IO_H +#include + #include "arrow/io/interfaces.h" #include "arrow/io/memory.h" #include "arrow/util/visibility.h" diff --git a/cpp/src/arrow/python/numpy_convert.cc b/cpp/src/arrow/python/numpy_convert.cc index 9ed2d73d42b57..c2d055fceed5a 100644 --- a/cpp/src/arrow/python/numpy_convert.cc +++ 
b/cpp/src/arrow/python/numpy_convert.cc @@ -84,6 +84,9 @@ NumPyBuffer::~NumPyBuffer() { Py_XDECREF(arr_); } break; Status GetTensorType(PyObject* dtype, std::shared_ptr* out) { + if (!PyArray_DescrCheck(dtype)) { + return Status::TypeError("Did not pass numpy.dtype object"); + } PyArray_Descr* descr = reinterpret_cast(dtype); int type_num = cast_npy_type_compat(descr->type_num); @@ -145,8 +148,14 @@ Status GetNumPyType(const DataType& type, int* type_num) { } Status NumPyDtypeToArrow(PyObject* dtype, std::shared_ptr* out) { + if (!PyArray_DescrCheck(dtype)) { + return Status::TypeError("Did not pass numpy.dtype object"); + } PyArray_Descr* descr = reinterpret_cast(dtype); + return NumPyDtypeToArrow(descr, out); +} +Status NumPyDtypeToArrow(PyArray_Descr* descr, std::shared_ptr* out) { int type_num = cast_npy_type_compat(descr->type_num); switch (type_num) { diff --git a/cpp/src/arrow/python/numpy_convert.h b/cpp/src/arrow/python/numpy_convert.h index 93c4848926cfc..220e38f2e1e02 100644 --- a/cpp/src/arrow/python/numpy_convert.h +++ b/cpp/src/arrow/python/numpy_convert.h @@ -56,6 +56,8 @@ bool is_contiguous(PyObject* array); ARROW_EXPORT Status NumPyDtypeToArrow(PyObject* dtype, std::shared_ptr* out); +ARROW_EXPORT +Status NumPyDtypeToArrow(PyArray_Descr* descr, std::shared_ptr* out); Status GetTensorType(PyObject* dtype, std::shared_ptr* out); Status GetNumPyType(const DataType& type, int* type_num); diff --git a/cpp/src/arrow/python/numpy_interop.h b/cpp/src/arrow/python/numpy_interop.h index b93200cc8972d..8c569e232c121 100644 --- a/cpp/src/arrow/python/numpy_interop.h +++ b/cpp/src/arrow/python/numpy_interop.h @@ -40,6 +40,7 @@ #endif #include +#include #include namespace arrow { diff --git a/cpp/src/arrow/python/numpy_to_arrow.cc b/cpp/src/arrow/python/numpy_to_arrow.cc index f21b40ed3c246..a1161fe32e100 100644 --- a/cpp/src/arrow/python/numpy_to_arrow.cc +++ b/cpp/src/arrow/python/numpy_to_arrow.cc @@ -29,6 +29,7 @@ #include #include #include +#include #include #include "arrow/array.h" @@ -175,7 +176,7 @@ static Status AppendObjectBinaries(PyArrayObject* arr, PyArrayObject* mask, continue; } else if (!PyBytes_Check(obj)) { std::stringstream ss; - ss << "Error converting to Python objects to bytes: "; + ss << "Error converting from Python objects to bytes: "; RETURN_NOT_OK(InvalidConversion(obj, "str, bytes", &ss)); return Status::Invalid(ss.str()); } @@ -230,7 +231,7 @@ static Status AppendObjectStrings(PyArrayObject* arr, PyArrayObject* mask, int64 *have_bytes = true; } else { std::stringstream ss; - ss << "Error converting to Python objects to String/UTF8: "; + ss << "Error converting from Python objects to String/UTF8: "; RETURN_NOT_OK(InvalidConversion(obj, "str, bytes", &ss)); return Status::Invalid(ss.str()); } @@ -278,7 +279,7 @@ static Status AppendObjectFixedWidthBytes(PyArrayObject* arr, PyArrayObject* mas tmp_obj.reset(obj); } else if (!PyBytes_Check(obj)) { std::stringstream ss; - ss << "Error converting to Python objects to FixedSizeBinary: "; + ss << "Error converting from Python objects to FixedSizeBinary: "; RETURN_NOT_OK(InvalidConversion(obj, "str, bytes", &ss)); return Status::Invalid(ss.str()); } @@ -1008,10 +1009,21 @@ Status NumPyConverter::ConvertObjectsInfer() { return ConvertTimes(); } else if (PyObject_IsInstance(const_cast(obj), Decimal.obj())) { return ConvertDecimals(); - } else if (PyList_Check(obj) || PyArray_Check(obj)) { + } else if (PyList_Check(obj)) { std::shared_ptr inferred_type; RETURN_NOT_OK(InferArrowType(obj, &inferred_type)); return 
ConvertLists(inferred_type); + } else if (PyArray_Check(obj)) { + std::shared_ptr inferred_type; + PyArray_Descr* dtype = PyArray_DESCR(reinterpret_cast(obj)); + + if (dtype->type_num == NPY_OBJECT) { + RETURN_NOT_OK(InferArrowType(obj, &inferred_type)); + } else { + RETURN_NOT_OK( + NumPyDtypeToArrow(reinterpret_cast(dtype), &inferred_type)); + } + return ConvertLists(inferred_type); } else { const std::string supported_types = "string, bool, float, int, date, time, decimal, list, array"; diff --git a/cpp/src/arrow/record_batch.cc b/cpp/src/arrow/record_batch.cc index 60932bdf3e4bb..d418cc4a2e66c 100644 --- a/cpp/src/arrow/record_batch.cc +++ b/cpp/src/arrow/record_batch.cc @@ -21,6 +21,7 @@ #include #include #include +#include #include "arrow/array.h" #include "arrow/status.h" diff --git a/cpp/src/arrow/stl-test.cc b/cpp/src/arrow/stl-test.cc new file mode 100644 index 0000000000000..c85baa3a11e3f --- /dev/null +++ b/cpp/src/arrow/stl-test.cc @@ -0,0 +1,78 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include "gtest/gtest.h" + +#include "arrow/stl.h" + +namespace arrow { +namespace stl { + +TEST(TestSchemaFromTuple, PrimitiveTypesVector) { + Schema expected_schema( + {field("column1", int8(), false), field("column2", int16(), false), + field("column3", int32(), false), field("column4", int64(), false), + field("column5", uint8(), false), field("column6", uint16(), false), + field("column7", uint32(), false), field("column8", uint64(), false), + field("column9", boolean(), false), field("column10", utf8(), false)}); + + std::shared_ptr schema = + SchemaFromTuple>:: + MakeSchema(std::vector({"column1", "column2", "column3", "column4", + "column5", "column6", "column7", "column8", + "column9", "column10"})); + ASSERT_TRUE(expected_schema.Equals(*schema)); +} + +TEST(TestSchemaFromTuple, PrimitiveTypesTuple) { + Schema expected_schema( + {field("column1", int8(), false), field("column2", int16(), false), + field("column3", int32(), false), field("column4", int64(), false), + field("column5", uint8(), false), field("column6", uint16(), false), + field("column7", uint32(), false), field("column8", uint64(), false), + field("column9", boolean(), false), field("column10", utf8(), false)}); + + std::shared_ptr schema = SchemaFromTuple< + std::tuple>::MakeSchema(std::make_tuple("column1", "column2", + "column3", "column4", + "column5", "column6", + "column7", "column8", + "column9", "column10")); + ASSERT_TRUE(expected_schema.Equals(*schema)); +} + +TEST(TestSchemaFromTuple, SimpleList) { + Schema expected_schema({field("column1", list(utf8()), false)}); + std::shared_ptr schema = + SchemaFromTuple>>::MakeSchema({"column1"}); + + ASSERT_TRUE(expected_schema.Equals(*schema)); +} + +TEST(TestSchemaFromTuple, NestedList) { + Schema expected_schema({field("column1", list(list(boolean())), false)}); + std::shared_ptr schema = + SchemaFromTuple>>>::MakeSchema( + {"column1"}); + + ASSERT_TRUE(expected_schema.Equals(*schema)); +} + +} // namespace stl +} // namespace arrow diff --git a/cpp/src/arrow/stl.h b/cpp/src/arrow/stl.h new file mode 100644 index 0000000000000..1e31ca769ae0b --- /dev/null +++ b/cpp/src/arrow/stl.h @@ -0,0 +1,154 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#ifndef ARROW_STL_H +#define ARROW_STL_H + +#include +#include +#include +#include + +#include "arrow/type.h" + +namespace arrow { + +class Schema; + +namespace stl { + +/// Traits meta class to map standard C/C++ types to equivalent Arrow types. 
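
The `ConversionTraits` primary template, its built-in specializations, and the `SchemaFromTuple` helper that consumes them are defined below. A hedged sketch of the intended end-to-end usage, assuming only what this header declares; the row type and column names are illustrative.

```cpp
// Sketch of building a Schema from a tuple-like row type via arrow::stl.
#include <cstdint>
#include <memory>
#include <string>
#include <tuple>
#include <vector>

#include "arrow/stl.h"

std::shared_ptr<arrow::Schema> MakeExampleSchema() {
  using RowType = std::tuple<int64_t, std::string, std::vector<bool>>;
  // Maps to an int64 column, a utf8 column, and a list<bool> column,
  // all marked non-nullable by the default ConversionTraits.
  return arrow::stl::SchemaFromTuple<RowType>::MakeSchema(
      std::vector<std::string>({"id", "name", "flags"}));
}
```
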
+template +struct ConversionTraits {}; + +#define ARROW_STL_CONVERSION(c_type, ArrowType_) \ + template <> \ + struct ConversionTraits { \ + using ArrowType = ArrowType_; \ + constexpr static bool nullable = false; \ + }; + +ARROW_STL_CONVERSION(bool, BooleanType) +ARROW_STL_CONVERSION(int8_t, Int8Type) +ARROW_STL_CONVERSION(int16_t, Int16Type) +ARROW_STL_CONVERSION(int32_t, Int32Type) +ARROW_STL_CONVERSION(int64_t, Int64Type) +ARROW_STL_CONVERSION(uint8_t, UInt8Type) +ARROW_STL_CONVERSION(uint16_t, UInt16Type) +ARROW_STL_CONVERSION(uint32_t, UInt32Type) +ARROW_STL_CONVERSION(uint64_t, UInt64Type) +ARROW_STL_CONVERSION(float, FloatType) +ARROW_STL_CONVERSION(double, DoubleType) +ARROW_STL_CONVERSION(std::string, StringType) + +template +struct ConversionTraits> { + using ArrowType = meta::ListType::ArrowType>; + constexpr static bool nullable = false; +}; + +/// Build an arrow::Schema based upon the types defined in a std::tuple-like structure. +/// +/// While the type information is available at compile-time, we still need to add the +/// column names at runtime, thus these methods are not constexpr. +template ::value> +struct SchemaFromTuple { + using Element = typename std::tuple_element::type; + using ArrowType = typename ConversionTraits::ArrowType; + + // Implementations that take a vector-like object for the column names. + + /// Recursively build a vector of arrow::Field from the defined types. + /// + /// In most cases MakeSchema is the better entrypoint for the Schema creation. + static std::vector> MakeSchemaRecursion( + const std::vector& names) { + std::vector> ret = + SchemaFromTuple::MakeSchemaRecursion(names); + ret.push_back(field(names[N - 1], std::make_shared(), + ConversionTraits::nullable)); + return ret; + } + + /// Build a Schema from the types of the tuple-like structure passed in as template + /// parameter assign the column names at runtime. + /// + /// An example usage of this API can look like the following: + /// + /// \code{.cpp} + /// using TupleType = std::tuple>; + /// std::shared_ptr schema = + /// SchemaFromTuple::MakeSchema({"int_column", "list_of_strings_column"}); + /// \endcode + static std::shared_ptr MakeSchema(const std::vector& names) { + return std::make_shared(MakeSchemaRecursion(names)); + } + + // Implementations that take a tuple-like object for the column names. + + /// Recursively build a vector of arrow::Field from the defined types. + /// + /// In most cases MakeSchema is the better entrypoint for the Schema creation. + template + static std::vector> MakeSchemaRecursionT( + const NamesTuple& names) { + std::vector> ret = + SchemaFromTuple::MakeSchemaRecursionT(names); + ret.push_back(field(std::get(names), std::make_shared(), + ConversionTraits::nullable)); + return ret; + } + + /// Build a Schema from the types of the tuple-like structure passed in as template + /// parameter assign the column names at runtime. 
+ /// + /// An example usage of this API can look like the following: + /// + /// \code{.cpp} + /// using TupleType = std::tuple>; + /// std::shared_ptr schema = + /// SchemaFromTuple::MakeSchema({"int_column", "list_of_strings_column"}); + /// \endcode + template + static std::shared_ptr MakeSchema(const NamesTuple& names) { + return std::make_shared(MakeSchemaRecursionT(names)); + } +}; + +template +struct SchemaFromTuple { + static std::vector> MakeSchemaRecursion( + const std::vector& names) { + std::vector> ret; + ret.reserve(names.size()); + return ret; + } + + template + static std::vector> MakeSchemaRecursionT( + const NamesTuple& names) { + std::vector> ret; + ret.reserve(std::tuple_size::value); + return ret; + } +}; +/// @endcond + +} // namespace stl +} // namespace arrow + +#endif // ARROW_STL_H diff --git a/cpp/src/arrow/table-test.cc b/cpp/src/arrow/table-test.cc index 3f1c6be3a87f6..99e4dd5db5146 100644 --- a/cpp/src/arrow/table-test.cc +++ b/cpp/src/arrow/table-test.cc @@ -108,6 +108,21 @@ TEST_F(TestChunkedArray, EqualsDifferingLengths) { ASSERT_TRUE(one_->Equals(*another_.get())); } +TEST_F(TestChunkedArray, SliceEquals) { + arrays_one_.push_back(MakeRandomArray(100)); + arrays_one_.push_back(MakeRandomArray(50)); + arrays_one_.push_back(MakeRandomArray(50)); + Construct(); + + std::shared_ptr slice = one_->Slice(125, 50); + ASSERT_EQ(slice->length(), 50); + ASSERT_TRUE(slice->Equals(one_->Slice(125, 50))); + + std::shared_ptr slice2 = one_->Slice(75)->Slice(25)->Slice(25, 50); + ASSERT_EQ(slice2->length(), 50); + ASSERT_TRUE(slice2->Equals(slice)); +} + class TestColumn : public TestChunkedArray { protected: void Construct() override { @@ -158,6 +173,22 @@ TEST_F(TestColumn, ChunksInhomogeneous) { ASSERT_RAISES(Invalid, column_->ValidateData()); } +TEST_F(TestColumn, SliceEquals) { + arrays_one_.push_back(MakeRandomArray(100)); + arrays_one_.push_back(MakeRandomArray(50)); + arrays_one_.push_back(MakeRandomArray(50)); + one_field_ = field("column", int32()); + Construct(); + + std::shared_ptr slice = one_col_->Slice(125, 50); + ASSERT_EQ(slice->length(), 50); + ASSERT_TRUE(slice->Equals(one_col_->Slice(125, 50))); + + std::shared_ptr slice2 = one_col_->Slice(75)->Slice(25)->Slice(25, 50); + ASSERT_EQ(slice2->length(), 50); + ASSERT_TRUE(slice2->Equals(slice)); +} + TEST_F(TestColumn, Equals) { std::vector null_bitmap(100, true); std::vector data(100, 1); diff --git a/cpp/src/arrow/table.cc b/cpp/src/arrow/table.cc index 2cf6c26523965..8cfd67faef1ee 100644 --- a/cpp/src/arrow/table.cc +++ b/cpp/src/arrow/table.cc @@ -22,6 +22,7 @@ #include #include #include +#include #include "arrow/array.h" #include "arrow/record_batch.h" @@ -102,6 +103,30 @@ bool ChunkedArray::Equals(const std::shared_ptr& other) const { return Equals(*other.get()); } +std::shared_ptr ChunkedArray::Slice(int64_t offset, int64_t length) const { + DCHECK_LE(offset, length_); + + int curr_chunk = 0; + while (offset >= chunk(curr_chunk)->length()) { + offset -= chunk(curr_chunk)->length(); + curr_chunk++; + } + + ArrayVector new_chunks; + while (length > 0 && curr_chunk < num_chunks()) { + new_chunks.push_back(chunk(curr_chunk)->Slice(offset, length)); + length -= chunk(curr_chunk)->length() - offset; + offset = 0; + curr_chunk++; + } + + return std::make_shared(new_chunks); +} + +std::shared_ptr ChunkedArray::Slice(int64_t offset) const { + return Slice(offset, length_); +} + Column::Column(const std::shared_ptr& field, const ArrayVector& chunks) : field_(field) { data_ = std::make_shared(chunks); diff 
--git a/cpp/src/arrow/table.h b/cpp/src/arrow/table.h index c813b32ad36dc..570a650e7fa4a 100644 --- a/cpp/src/arrow/table.h +++ b/cpp/src/arrow/table.h @@ -44,6 +44,7 @@ class ARROW_EXPORT ChunkedArray { /// \return the total length of the chunked array; computed on construction int64_t length() const { return length_; } + /// \return the total number of nulls among all chunks int64_t null_count() const { return null_count_; } int num_chunks() const { return static_cast(chunks_.size()); } @@ -53,6 +54,20 @@ class ARROW_EXPORT ChunkedArray { const ArrayVector& chunks() const { return chunks_; } + /// \brief Construct a zero-copy slice of the chunked array with the + /// indicated offset and length + /// + /// \param[in] offset the position of the first element in the constructed + /// slice + /// \param[in] length the length of the slice. If there are not enough + /// elements in the chunked array, the length will be adjusted accordingly + /// + /// \return a new object wrapped in std::shared_ptr + std::shared_ptr Slice(int64_t offset, int64_t length) const; + + /// \brief Slice from offset until end of the chunked array + std::shared_ptr Slice(int64_t offset) const; + std::shared_ptr type() const; bool Equals(const ChunkedArray& other) const; @@ -67,8 +82,9 @@ class ARROW_EXPORT ChunkedArray { ARROW_DISALLOW_COPY_AND_ASSIGN(ChunkedArray); }; +/// \class Column /// \brief An immutable column data structure consisting of a field (type -/// metadata) and a logical chunked data array +/// metadata) and a chunked data array class ARROW_EXPORT Column { public: Column(const std::shared_ptr& field, const ArrayVector& chunks); @@ -97,6 +113,24 @@ class ARROW_EXPORT Column { /// \return the column's data as a chunked logical array std::shared_ptr data() const { return data_; } + /// \brief Construct a zero-copy slice of the column with the indicated + /// offset and length + /// + /// \param[in] offset the position of the first element in the constructed + /// slice + /// \param[in] length the length of the slice. If there are not enough + /// elements in the column, the length will be adjusted accordingly + /// + /// \return a new object wrapped in std::shared_ptr + std::shared_ptr Slice(int64_t offset, int64_t length) const { + return std::make_shared(field_, data_->Slice(offset, length)); + } + + /// \brief Slice from offset until end of the column + std::shared_ptr Slice(int64_t offset) const { + return std::make_shared(field_, data_->Slice(offset)); + } + bool Equals(const Column& other) const; bool Equals(const std::shared_ptr& other) const; diff --git a/cpp/src/arrow/table_builder.cc b/cpp/src/arrow/table_builder.cc index 379d886deacba..8e9babcc3997a 100644 --- a/cpp/src/arrow/table_builder.cc +++ b/cpp/src/arrow/table_builder.cc @@ -21,6 +21,7 @@ #include #include #include +#include #include "arrow/array.h" #include "arrow/builder.h" diff --git a/cpp/src/arrow/type.cc b/cpp/src/arrow/type.cc index 31ad53458112c..0a2889f040026 100644 --- a/cpp/src/arrow/type.cc +++ b/cpp/src/arrow/type.cc @@ -20,6 +20,8 @@ #include #include #include +#include +#include #include "arrow/array.h" #include "arrow/compare.h" diff --git a/cpp/src/arrow/type.h b/cpp/src/arrow/type.h index 009e07db07744..cfee6fd0e2363 100644 --- a/cpp/src/arrow/type.h +++ b/cpp/src/arrow/type.h @@ -407,6 +407,19 @@ class ARROW_EXPORT ListType : public NestedType { std::string name() const override { return "list"; } }; +namespace meta { + +/// Additional ListType class that can be instantiated with only compile-time arguments. 
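Before moving on, a small usage sketch for the zero-copy slicing API added to `arrow/table.h` above (not part of the patch; the chunk lengths match the ones used in the tests):

```cpp
#include <memory>

#include "arrow/table.h"

// Slicing never copies data: the returned object shares buffers with the
// input chunks, and a request that runs past the end is clamped to the
// elements actually available.
std::shared_ptr<arrow::Column> MiddleFifty(
    const std::shared_ptr<arrow::Column>& column) {
  // With chunks of length 100/50/50 this spans the tail of the second chunk
  // and the head of the third one.
  return column->Slice(/*offset=*/125, /*length=*/50);
}
```

As the header shows, `Column::Slice` simply forwards to `ChunkedArray::Slice` on its data and rewraps the result with the same field.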
+template +class ARROW_EXPORT ListType : public ::arrow::ListType { + public: + using ValueType = T; + + ListType() : ::arrow::ListType(std::make_shared()) {} +}; + +} // namespace meta + // BinaryType type is represents lists of 1-byte values. class ARROW_EXPORT BinaryType : public DataType, public NoExtraMeta { public: diff --git a/cpp/src/arrow/type_traits.h b/cpp/src/arrow/type_traits.h index 4bfce9b5f0c53..ede52e9b84bb6 100644 --- a/cpp/src/arrow/type_traits.h +++ b/cpp/src/arrow/type_traits.h @@ -18,6 +18,7 @@ #ifndef ARROW_TYPE_TRAITS_H #define ARROW_TYPE_TRAITS_H +#include #include #include "arrow/type_fwd.h" diff --git a/cpp/src/arrow/util/io-util.h b/cpp/src/arrow/util/io-util.h index 7e2a94ca82320..d1af6c666a156 100644 --- a/cpp/src/arrow/util/io-util.h +++ b/cpp/src/arrow/util/io-util.h @@ -19,6 +19,7 @@ #define ARROW_UTIL_IO_UTIL_H #include +#include #include "arrow/buffer.h" #include "arrow/io/interfaces.h" diff --git a/cpp/src/plasma/client.cc b/cpp/src/plasma/client.cc index d74c0f412d97f..6e9b6968a8673 100644 --- a/cpp/src/plasma/client.cc +++ b/cpp/src/plasma/client.cc @@ -130,7 +130,7 @@ void PlasmaClient::increment_object_count(const ObjectID& object_id, PlasmaObjec // Increment the count of the number of objects in the memory-mapped file // that are being used. The corresponding decrement should happen in // PlasmaClient::Release. - auto entry = mmap_table_.find(object->handle.store_fd); + auto entry = mmap_table_.find(object->store_fd); ARROW_CHECK(entry != mmap_table_.end()); ARROW_CHECK(entry->second.count >= 0); // Update the in_use_object_bytes_. @@ -157,7 +157,10 @@ Status PlasmaClient::Create(const ObjectID& object_id, int64_t data_size, RETURN_NOT_OK(PlasmaReceive(store_conn_, MessageType_PlasmaCreateReply, &buffer)); ObjectID id; PlasmaObject object; - RETURN_NOT_OK(ReadCreateReply(buffer.data(), buffer.size(), &id, &object)); + int store_fd; + int64_t mmap_size; + RETURN_NOT_OK( + ReadCreateReply(buffer.data(), buffer.size(), &id, &object, &store_fd, &mmap_size)); // If the CreateReply included an error, then the store will not send a file // descriptor. int fd = recv_fd(store_conn_); @@ -167,9 +170,7 @@ Status PlasmaClient::Create(const ObjectID& object_id, int64_t data_size, // The metadata should come right after the data. ARROW_CHECK(object.metadata_offset == object.data_offset + data_size); *data = std::make_shared( - lookup_or_mmap(fd, object.handle.store_fd, object.handle.mmap_size) + - object.data_offset, - data_size); + lookup_or_mmap(fd, store_fd, mmap_size) + object.data_offset, data_size); // If plasma_create is being called from a transfer, then we will not copy the // metadata here. The metadata will be written along with the data streamed // from the transfer. 
@@ -209,7 +210,7 @@ Status PlasmaClient::Get(const ObjectID* object_ids, int64_t num_objects, ARROW_CHECK(object_entry->second->is_sealed) << "Plasma client called get on an unsealed object that it created"; PlasmaObject* object = &object_entry->second->object; - uint8_t* data = lookup_mmapped_file(object->handle.store_fd); + uint8_t* data = lookup_mmapped_file(object->store_fd); object_buffers[i].data = std::make_shared(data + object->data_offset, object->data_size); object_buffers[i].metadata = std::make_shared( @@ -236,8 +237,19 @@ Status PlasmaClient::Get(const ObjectID* object_ids, int64_t num_objects, std::vector received_object_ids(num_objects); std::vector object_data(num_objects); PlasmaObject* object; + std::vector store_fds; + std::vector mmap_sizes; RETURN_NOT_OK(ReadGetReply(buffer.data(), buffer.size(), received_object_ids.data(), - object_data.data(), num_objects)); + object_data.data(), num_objects, store_fds, mmap_sizes)); + + // We mmap all of the file descriptors here so that we can avoid look them up + // in the subsequent loop based on just the store file descriptor and without + // having to know the relevant file descriptor received from recv_fd. + for (size_t i = 0; i < store_fds.size(); i++) { + int fd = recv_fd(store_conn_); + ARROW_CHECK(fd >= 0); + lookup_or_mmap(fd, store_fds[i], mmap_sizes[i]); + } for (int i = 0; i < num_objects; ++i) { DCHECK(received_object_ids[i] == object_ids[i]); @@ -246,12 +258,6 @@ Status PlasmaClient::Get(const ObjectID* object_ids, int64_t num_objects, // If the object was already in use by the client, then the store should // have returned it. DCHECK_NE(object->data_size, -1); - // We won't use this file descriptor, but the store sent us one, so we - // need to receive it and then close it right away so we don't leak file - // descriptors. - int fd = recv_fd(store_conn_); - close(fd); - ARROW_CHECK(fd >= 0); // We've already filled out the information for this object, so we can // just continue. continue; @@ -259,12 +265,7 @@ Status PlasmaClient::Get(const ObjectID* object_ids, int64_t num_objects, // If we are here, the object was not currently in use, so we need to // process the reply from the object store. if (object->data_size != -1) { - // The object was retrieved. The user will be responsible for releasing - // this object. - int fd = recv_fd(store_conn_); - uint8_t* data = - lookup_or_mmap(fd, object->handle.store_fd, object->handle.mmap_size); - ARROW_CHECK(fd >= 0); + uint8_t* data = lookup_mmapped_file(object->store_fd); // Finish filling out the return values. object_buffers[i].data = std::make_shared(data + object->data_offset, object->data_size); @@ -296,7 +297,7 @@ Status PlasmaClient::UnmapObject(const ObjectID& object_id) { // Decrement the count of the number of objects in this memory-mapped file // that the client is using. The corresponding increment should have // happened in plasma_get. 
- int fd = object_entry->second->object.handle.store_fd; + int fd = object_entry->second->object.store_fd; auto entry = mmap_table_.find(fd); ARROW_CHECK(entry != mmap_table_.end()); ARROW_CHECK(entry->second.count >= 1); diff --git a/cpp/src/plasma/client.h b/cpp/src/plasma/client.h index 35182f8403201..d6372f44a7f28 100644 --- a/cpp/src/plasma/client.h +++ b/cpp/src/plasma/client.h @@ -31,8 +31,8 @@ #include "arrow/util/visibility.h" #include "plasma/common.h" -using arrow::Status; using arrow::Buffer; +using arrow::Status; namespace plasma { diff --git a/cpp/src/plasma/events.cc b/cpp/src/plasma/events.cc index 4e4ecfaaaca31..ce29e6c321d5d 100644 --- a/cpp/src/plasma/events.cc +++ b/cpp/src/plasma/events.cc @@ -17,6 +17,8 @@ #include "plasma/events.h" +#include + #include namespace plasma { diff --git a/cpp/src/plasma/fling.cc b/cpp/src/plasma/fling.cc index b84648b25a9e7..26afd87066c2b 100644 --- a/cpp/src/plasma/fling.cc +++ b/cpp/src/plasma/fling.cc @@ -23,7 +23,7 @@ void init_msg(struct msghdr* msg, struct iovec* iov, char* buf, size_t buf_len) msg->msg_iov = iov; msg->msg_iovlen = 1; msg->msg_control = buf; - msg->msg_controllen = buf_len; + msg->msg_controllen = static_cast(buf_len); msg->msg_name = NULL; msg->msg_namelen = 0; } @@ -43,7 +43,7 @@ int send_fd(int conn, int fd) { header->cmsg_level = SOL_SOCKET; header->cmsg_type = SCM_RIGHTS; header->cmsg_len = CMSG_LEN(sizeof(int)); - *reinterpret_cast(CMSG_DATA(header)) = fd; + memcpy(CMSG_DATA(header), reinterpret_cast(&fd), sizeof(int)); // Send file descriptor. ssize_t r = sendmsg(conn, &msg, 0); diff --git a/cpp/src/plasma/format/plasma.fbs b/cpp/src/plasma/format/plasma.fbs index ea6dc8bb98da5..33803f7799ba0 100644 --- a/cpp/src/plasma/format/plasma.fbs +++ b/cpp/src/plasma/format/plasma.fbs @@ -89,8 +89,6 @@ struct PlasmaObjectSpec { // Index of the memory segment (= memory mapped file) that // this object is allocated in. segment_index: int; - // Size in bytes of this segment (needed to call mmap). - mmap_size: ulong; // The offset in bytes in the memory mapped file of the data. data_offset: ulong; // The size in bytes of the data. @@ -117,6 +115,12 @@ table PlasmaCreateReply { plasma_object: PlasmaObjectSpec; // Error that occurred for this call. error: PlasmaError; + // The file descriptor in the store that corresponds to the file descriptor + // being sent to the client right after this message. + store_fd: int; + // The size in bytes of the segment for the store file descriptor (needed to + // call mmap). + mmap_size: long; } table PlasmaAbortRequest { @@ -156,9 +160,17 @@ table PlasmaGetReply { // objects if not all requested objects are stored and sealed // in the local Plasma store. object_ids: [string]; - // Plasma object information, in the same order as their IDs. + // Plasma object information, in the same order as their IDs. The number of + // elements in both object_ids and plasma_objects arrays must agree. plasma_objects: [PlasmaObjectSpec]; - // The number of elements in both object_ids and plasma_objects arrays must agree. + // A list of the file descriptors in the store that correspond to the file + // descriptors being sent to the client. The length of this list is the number + // of file descriptors that the store will send to the client after this + // message. + store_fds: [int]; + // Size in bytes of the segment for each store file descriptor (needed to call + // mmap). This list must have the same length as store_fds. 
+ mmap_sizes: [long]; } table PlasmaReleaseRequest { diff --git a/cpp/src/plasma/malloc.cc b/cpp/src/plasma/malloc.cc index 52d362013f1ae..3c5d107b2bbe3 100644 --- a/cpp/src/plasma/malloc.cc +++ b/cpp/src/plasma/malloc.cc @@ -197,4 +197,14 @@ void get_malloc_mapinfo(void* addr, int* fd, int64_t* map_size, ptrdiff_t* offse *offset = 0; } +int64_t get_mmap_size(int fd) { + for (const auto& entry : mmap_records) { + if (entry.second.fd == fd) { + return entry.second.size; + } + } + ARROW_LOG(FATAL) << "failed to find entry in mmap_records for fd " << fd; + return -1; // This code is never reached. +} + void set_malloc_granularity(int value) { change_mparam(M_GRANULARITY, value); } diff --git a/cpp/src/plasma/malloc.h b/cpp/src/plasma/malloc.h index 0df720db59817..cb8c600b14b3b 100644 --- a/cpp/src/plasma/malloc.h +++ b/cpp/src/plasma/malloc.h @@ -23,6 +23,12 @@ void get_malloc_mapinfo(void* addr, int* fd, int64_t* map_length, ptrdiff_t* offset); +/// Get the mmap size corresponding to a specific file descriptor. +/// +/// @param fd The file descriptor to look up. +/// @return The size of the corresponding memory-mapped file. +int64_t get_mmap_size(int fd); + void set_malloc_granularity(int value); #endif // MALLOC_H diff --git a/cpp/src/plasma/plasma.h b/cpp/src/plasma/plasma.h index 603ff8a4fac6c..bb9cdae601146 100644 --- a/cpp/src/plasma/plasma.h +++ b/cpp/src/plasma/plasma.h @@ -27,6 +27,7 @@ #include #include // pid_t +#include #include #include #include @@ -64,20 +65,12 @@ struct Client; /// Mapping from object IDs to type and status of the request. typedef std::unordered_map ObjectRequestMap; -/// Handle to access memory mapped file and map it into client address space. -struct object_handle { +// TODO(pcm): Replace this by the flatbuffers message PlasmaObjectSpec. +struct PlasmaObject { /// The file descriptor of the memory mapped file in the store. It is used as /// a unique identifier of the file in the client to look up the corresponding /// file descriptor on the client's side. int store_fd; - /// The size in bytes of the memory mapped file. - int64_t mmap_size; -}; - -// TODO(pcm): Replace this by the flatbuffers message PlasmaObjectSpec. -struct PlasmaObject { - /// Handle for memory mapped file the object is stored in. - object_handle handle; /// The offset in bytes in the memory mapped file of the data. ptrdiff_t data_offset; /// The offset in bytes in the memory mapped file of the metadata. 
diff --git a/cpp/src/plasma/protocol.cc b/cpp/src/plasma/protocol.cc index c0ebb88fe5019..6c0bc0cab28bb 100644 --- a/cpp/src/plasma/protocol.cc +++ b/cpp/src/plasma/protocol.cc @@ -73,30 +73,32 @@ Status ReadCreateRequest(uint8_t* data, size_t size, ObjectID* object_id, return Status::OK(); } -Status SendCreateReply(int sock, ObjectID object_id, PlasmaObject* object, - int error_code) { +Status SendCreateReply(int sock, ObjectID object_id, PlasmaObject* object, int error_code, + int64_t mmap_size) { flatbuffers::FlatBufferBuilder fbb; - PlasmaObjectSpec plasma_object(object->handle.store_fd, object->handle.mmap_size, - object->data_offset, object->data_size, + PlasmaObjectSpec plasma_object(object->store_fd, object->data_offset, object->data_size, object->metadata_offset, object->metadata_size); - auto message = - CreatePlasmaCreateReply(fbb, fbb.CreateString(object_id.binary()), &plasma_object, - static_cast(error_code)); + auto message = CreatePlasmaCreateReply( + fbb, fbb.CreateString(object_id.binary()), &plasma_object, + static_cast(error_code), object->store_fd, mmap_size); return PlasmaSend(sock, MessageType_PlasmaCreateReply, &fbb, message); } Status ReadCreateReply(uint8_t* data, size_t size, ObjectID* object_id, - PlasmaObject* object) { + PlasmaObject* object, int* store_fd, int64_t* mmap_size) { DCHECK(data); auto message = flatbuffers::GetRoot(data); DCHECK(verify_flatbuffer(message, data, size)); *object_id = ObjectID::from_binary(message->object_id()->str()); - object->handle.store_fd = message->plasma_object()->segment_index(); - object->handle.mmap_size = message->plasma_object()->mmap_size(); + object->store_fd = message->plasma_object()->segment_index(); object->data_offset = message->plasma_object()->data_offset(); object->data_size = message->plasma_object()->data_size(); object->metadata_offset = message->plasma_object()->metadata_offset(); object->metadata_size = message->plasma_object()->metadata_size(); + + *store_fd = message->store_fd(); + *mmap_size = message->mmap_size(); + return plasma_error_status(message->error()); } @@ -389,24 +391,29 @@ Status ReadGetRequest(uint8_t* data, size_t size, std::vector& object_ Status SendGetReply( int sock, ObjectID object_ids[], std::unordered_map& plasma_objects, - int64_t num_objects) { + int64_t num_objects, const std::vector& store_fds, + const std::vector& mmap_sizes) { flatbuffers::FlatBufferBuilder fbb; std::vector objects; - for (int i = 0; i < num_objects; ++i) { + ARROW_CHECK(store_fds.size() == mmap_sizes.size()); + + for (int64_t i = 0; i < num_objects; ++i) { const PlasmaObject& object = plasma_objects[object_ids[i]]; - objects.push_back(PlasmaObjectSpec(object.handle.store_fd, object.handle.mmap_size, - object.data_offset, object.data_size, - object.metadata_offset, object.metadata_size)); + objects.push_back(PlasmaObjectSpec(object.store_fd, object.data_offset, + object.data_size, object.metadata_offset, + object.metadata_size)); } auto message = CreatePlasmaGetReply(fbb, to_flatbuffer(&fbb, object_ids, num_objects), - fbb.CreateVectorOfStructs(objects.data(), num_objects)); + fbb.CreateVectorOfStructs(objects.data(), num_objects), + fbb.CreateVector(store_fds), fbb.CreateVector(mmap_sizes)); return PlasmaSend(sock, MessageType_PlasmaGetReply, &fbb, message); } Status ReadGetReply(uint8_t* data, size_t size, ObjectID object_ids[], - PlasmaObject plasma_objects[], int64_t num_objects) { + PlasmaObject plasma_objects[], int64_t num_objects, + std::vector& store_fds, std::vector& mmap_sizes) { DCHECK(data); auto 
message = flatbuffers::GetRoot(data); DCHECK(verify_flatbuffer(message, data, size)); @@ -415,13 +422,17 @@ Status ReadGetReply(uint8_t* data, size_t size, ObjectID object_ids[], } for (uoffset_t i = 0; i < num_objects; ++i) { const PlasmaObjectSpec* object = message->plasma_objects()->Get(i); - plasma_objects[i].handle.store_fd = object->segment_index(); - plasma_objects[i].handle.mmap_size = object->mmap_size(); + plasma_objects[i].store_fd = object->segment_index(); plasma_objects[i].data_offset = object->data_offset(); plasma_objects[i].data_size = object->data_size(); plasma_objects[i].metadata_offset = object->metadata_offset(); plasma_objects[i].metadata_size = object->metadata_size(); } + ARROW_CHECK(message->store_fds()->size() == message->mmap_sizes()->size()); + for (uoffset_t i = 0; i < message->store_fds()->size(); i++) { + store_fds.push_back(message->store_fds()->Get(i)); + mmap_sizes.push_back(message->mmap_sizes()->Get(i)); + } return Status::OK(); } diff --git a/cpp/src/plasma/protocol.h b/cpp/src/plasma/protocol.h index e8c334f9181fc..101a3faa7675e 100644 --- a/cpp/src/plasma/protocol.h +++ b/cpp/src/plasma/protocol.h @@ -18,6 +18,8 @@ #ifndef PLASMA_PROTOCOL_H #define PLASMA_PROTOCOL_H +#include +#include #include #include "arrow/status.h" @@ -46,10 +48,11 @@ Status SendCreateRequest(int sock, ObjectID object_id, int64_t data_size, Status ReadCreateRequest(uint8_t* data, size_t size, ObjectID* object_id, int64_t* data_size, int64_t* metadata_size); -Status SendCreateReply(int sock, ObjectID object_id, PlasmaObject* object, int error); +Status SendCreateReply(int sock, ObjectID object_id, PlasmaObject* object, int error, + int64_t mmap_size); Status ReadCreateReply(uint8_t* data, size_t size, ObjectID* object_id, - PlasmaObject* object); + PlasmaObject* object, int* store_fd, int64_t* mmap_size); Status SendAbortRequest(int sock, ObjectID object_id); @@ -81,10 +84,12 @@ Status ReadGetRequest(uint8_t* data, size_t size, std::vector& object_ Status SendGetReply( int sock, ObjectID object_ids[], std::unordered_map& plasma_objects, - int64_t num_objects); + int64_t num_objects, const std::vector& store_fds, + const std::vector& mmap_sizes); Status ReadGetReply(uint8_t* data, size_t size, ObjectID object_ids[], - PlasmaObject plasma_objects[], int64_t num_objects); + PlasmaObject plasma_objects[], int64_t num_objects, + std::vector& store_fds, std::vector& mmap_sizes); /* Plasma Release message functions. 
*/ diff --git a/cpp/src/plasma/store.cc b/cpp/src/plasma/store.cc index dde7f9cdfa8eb..316a27f63f680 100644 --- a/cpp/src/plasma/store.cc +++ b/cpp/src/plasma/store.cc @@ -44,9 +44,11 @@ #include #include +#include #include #include #include +#include #include #include "plasma/common.h" @@ -192,8 +194,7 @@ int PlasmaStore::create_object(const ObjectID& object_id, int64_t data_size, entry->state = PLASMA_CREATED; store_info_.objects[object_id] = std::move(entry); - result->handle.store_fd = fd; - result->handle.mmap_size = map_size; + result->store_fd = fd; result->data_offset = offset; result->metadata_offset = offset + data_size; result->data_size = data_size; @@ -211,8 +212,7 @@ void PlasmaObject_init(PlasmaObject* object, ObjectTableEntry* entry) { DCHECK(object != NULL); DCHECK(entry != NULL); DCHECK(entry->state == PLASMA_SEALED); - object->handle.store_fd = entry->fd; - object->handle.mmap_size = entry->map_size; + object->store_fd = entry->fd; object->data_offset = entry->offset; object->metadata_offset = entry->offset + entry->info.data_size; object->data_size = entry->info.data_size; @@ -220,34 +220,44 @@ void PlasmaObject_init(PlasmaObject* object, ObjectTableEntry* entry) { } void PlasmaStore::return_from_get(GetRequest* get_req) { + // Figure out how many file descriptors we need to send. + std::unordered_set fds_to_send; + std::vector store_fds; + std::vector mmap_sizes; + for (const auto& object_id : get_req->object_ids) { + PlasmaObject& object = get_req->objects[object_id]; + int fd = object.store_fd; + if (object.data_size != -1 && fds_to_send.count(fd) == 0) { + fds_to_send.insert(fd); + store_fds.push_back(fd); + mmap_sizes.push_back(get_mmap_size(fd)); + } + } + // Send the get reply to the client. Status s = SendGetReply(get_req->client->fd, &get_req->object_ids[0], get_req->objects, - get_req->object_ids.size()); + get_req->object_ids.size(), store_fds, mmap_sizes); warn_if_sigpipe(s.ok() ? 0 : -1, get_req->client->fd); // If we successfully sent the get reply message to the client, then also send // the file descriptors. if (s.ok()) { // Send all of the file descriptors for the present objects. - for (const auto& object_id : get_req->object_ids) { - PlasmaObject& object = get_req->objects[object_id]; - // We use the data size to indicate whether the object is present or not. - if (object.data_size != -1) { - int error_code = send_fd(get_req->client->fd, object.handle.store_fd); - // If we failed to send the file descriptor, loop until we have sent it - // successfully. TODO(rkn): This is problematic for two reasons. First - // of all, sending the file descriptor should just succeed without any - // errors, but sometimes I see a "Message too long" error number. - // Second, looping like this allows a client to potentially block the - // plasma store event loop which should never happen. - while (error_code < 0) { - if (errno == EMSGSIZE) { - ARROW_LOG(WARNING) << "Failed to send file descriptor, retrying."; - error_code = send_fd(get_req->client->fd, object.handle.store_fd); - continue; - } - warn_if_sigpipe(error_code, get_req->client->fd); - break; + for (int store_fd : store_fds) { + int error_code = send_fd(get_req->client->fd, store_fd); + // If we failed to send the file descriptor, loop until we have sent it + // successfully. TODO(rkn): This is problematic for two reasons. First + // of all, sending the file descriptor should just succeed without any + // errors, but sometimes I see a "Message too long" error number. 
+ // Second, looping like this allows a client to potentially block the + // plasma store event loop which should never happen. + while (error_code < 0) { + if (errno == EMSGSIZE) { + ARROW_LOG(WARNING) << "Failed to send file descriptor, retrying."; + error_code = send_fd(get_req->client->fd, store_fd); + continue; } + warn_if_sigpipe(error_code, get_req->client->fd); + break; } } } @@ -640,10 +650,15 @@ Status PlasmaStore::process_message(Client* client) { ReadCreateRequest(input, input_size, &object_id, &data_size, &metadata_size)); int error_code = create_object(object_id, data_size, metadata_size, client, &object); - HANDLE_SIGPIPE(SendCreateReply(client->fd, object_id, &object, error_code), - client->fd); + int64_t mmap_size = 0; + if (error_code == PlasmaError_OK) { + mmap_size = get_mmap_size(object.store_fd); + } + HANDLE_SIGPIPE( + SendCreateReply(client->fd, object_id, &object, error_code, mmap_size), + client->fd); if (error_code == PlasmaError_OK) { - warn_if_sigpipe(send_fd(client->fd, object.handle.store_fd), client->fd); + warn_if_sigpipe(send_fd(client->fd, object.store_fd), client->fd); } } break; case MessageType_PlasmaAbortRequest: { diff --git a/cpp/src/plasma/store.h b/cpp/src/plasma/store.h index 7eada5a126991..7e716d284f389 100644 --- a/cpp/src/plasma/store.h +++ b/cpp/src/plasma/store.h @@ -19,7 +19,9 @@ #define PLASMA_STORE_H #include +#include #include +#include #include #include "plasma/common.h" diff --git a/cpp/src/plasma/test/serialization_tests.cc b/cpp/src/plasma/test/serialization_tests.cc index b593b6ae94890..656b2cc6b9bca 100644 --- a/cpp/src/plasma/test/serialization_tests.cc +++ b/cpp/src/plasma/test/serialization_tests.cc @@ -63,8 +63,7 @@ PlasmaObject random_plasma_object(void) { int random = rand_r(&seed); PlasmaObject object; memset(&object, 0, sizeof(object)); - object.handle.store_fd = random + 7; - object.handle.mmap_size = random + 42; + object.store_fd = random + 7; object.data_offset = random + 1; object.metadata_offset = random + 2; object.data_size = random + 3; @@ -94,13 +93,19 @@ TEST(PlasmaSerialization, CreateReply) { int fd = create_temp_file(); ObjectID object_id1 = ObjectID::from_random(); PlasmaObject object1 = random_plasma_object(); - ARROW_CHECK_OK(SendCreateReply(fd, object_id1, &object1, 0)); + int64_t mmap_size1 = 1000000; + ARROW_CHECK_OK(SendCreateReply(fd, object_id1, &object1, 0, mmap_size1)); std::vector data = read_message_from_file(fd, MessageType_PlasmaCreateReply); ObjectID object_id2; PlasmaObject object2; memset(&object2, 0, sizeof(object2)); - ARROW_CHECK_OK(ReadCreateReply(data.data(), data.size(), &object_id2, &object2)); + int store_fd; + int64_t mmap_size2; + ARROW_CHECK_OK(ReadCreateReply(data.data(), data.size(), &object_id2, &object2, + &store_fd, &mmap_size2)); ASSERT_EQ(object_id1, object_id2); + ASSERT_EQ(object1.store_fd, store_fd); + ASSERT_EQ(mmap_size1, mmap_size2); ASSERT_EQ(memcmp(&object1, &object2, sizeof(object1)), 0); close(fd); } @@ -158,13 +163,20 @@ TEST(PlasmaSerialization, GetReply) { std::unordered_map plasma_objects; plasma_objects[object_ids[0]] = random_plasma_object(); plasma_objects[object_ids[1]] = random_plasma_object(); - ARROW_CHECK_OK(SendGetReply(fd, object_ids, plasma_objects, 2)); + std::vector store_fds = {1, 2, 3}; + std::vector mmap_sizes = {100, 200, 300}; + ARROW_CHECK_OK(SendGetReply(fd, object_ids, plasma_objects, 2, store_fds, mmap_sizes)); + std::vector data = read_message_from_file(fd, MessageType_PlasmaGetReply); ObjectID object_ids_return[2]; PlasmaObject 
plasma_objects_return[2]; + std::vector store_fds_return; + std::vector mmap_sizes_return; memset(&plasma_objects_return, 0, sizeof(plasma_objects_return)); ARROW_CHECK_OK(ReadGetReply(data.data(), data.size(), object_ids_return, - &plasma_objects_return[0], 2)); + &plasma_objects_return[0], 2, store_fds_return, + mmap_sizes_return)); + ASSERT_EQ(object_ids[0], object_ids_return[0]); ASSERT_EQ(object_ids[1], object_ids_return[1]); ASSERT_EQ(memcmp(&plasma_objects[object_ids[0]], &plasma_objects_return[0], @@ -173,6 +185,8 @@ TEST(PlasmaSerialization, GetReply) { ASSERT_EQ(memcmp(&plasma_objects[object_ids[1]], &plasma_objects_return[1], sizeof(PlasmaObject)), 0); + ASSERT_TRUE(store_fds == store_fds_return); + ASSERT_TRUE(mmap_sizes == mmap_sizes_return); close(fd); } diff --git a/dev/gen_apidocs/create_documents.sh b/dev/gen_apidocs/create_documents.sh index 54031262b3a5d..3100d3b984b3a 100755 --- a/dev/gen_apidocs/create_documents.sh +++ b/dev/gen_apidocs/create_documents.sh @@ -87,8 +87,6 @@ if [ -f Makefile ]; then # Ensure updating to prevent auto re-configure touch configure **/Makefile make distclean - # Work around for 'make distclean' removes doc/reference/xml/ - git checkout doc/reference/xml fi ./autogen.sh rm -rf build_docs diff --git a/java/vector/src/main/java/org/apache/arrow/vector/BaseVariableWidthVector.java b/java/vector/src/main/java/org/apache/arrow/vector/BaseVariableWidthVector.java index fff329a9b9d66..d1190ceb7b672 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/BaseVariableWidthVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/BaseVariableWidthVector.java @@ -169,6 +169,42 @@ public void setInitialCapacity(int valueCount) { offsetAllocationSizeInBytes = (valueCount + 1) * OFFSET_WIDTH; } + /** + * Sets the desired value capacity for the vector. This function doesn't + * allocate any memory for the vector. + * @param valueCount desired number of elements in the vector + * @param density average number of bytes per variable width element + */ + public void setInitialCapacity(int valueCount, double density) { + final long size = (long) (valueCount * density); + if (size < 1) { + throw new IllegalArgumentException("With the provided density and value count, potential capacity of the data buffer is 0"); + } + if (size > MAX_ALLOCATION_SIZE) { + throw new OversizedAllocationException("Requested amount of memory is more than max allowed"); + } + valueAllocationSizeInBytes = (int) size; + validityAllocationSizeInBytes = getValidityBufferSizeFromCount(valueCount); + /* to track the end offset of last data element in vector, we need + * an additional slot in offset buffer. + */ + offsetAllocationSizeInBytes = (valueCount + 1) * OFFSET_WIDTH; + } + + /** + * Get the density of this ListVector + * @return density + */ + public double getDensity() { + if (valueCount == 0) { + return 0.0D; + } + final int startOffset = offsetBuffer.getInt(0); + final int endOffset = offsetBuffer.getInt(valueCount * OFFSET_WIDTH); + final double totalListSize = endOffset - startOffset; + return totalListSize/valueCount; + } + /** * Get the current value capacity for the vector * @return number of elements that vector can hold. 
diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseRepeatedValueVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseRepeatedValueVector.java index d0a664ac01da2..50ee3a7573efe 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseRepeatedValueVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/BaseRepeatedValueVector.java @@ -143,6 +143,38 @@ public void setInitialCapacity(int numRecords) { } } + /** + * Specialized version of setInitialCapacity() for ListVector. This is + * used by some callers when they want to explicitly control and be + * conservative about memory allocated for inner data vector. This is + * very useful when we are working with memory constraints for a query + * and have a fixed amount of memory reserved for the record batch. In + * such cases, we are likely to face OOM or related problems when + * we reserve memory for a record batch with value count x and + * do setInitialCapacity(x) such that each vector allocates only + * what is necessary and not the default amount but the multiplier + * forces the memory requirement to go beyond what was needed. + * + * @param numRecords value count + * @param density density of ListVector. Density is the average size of + * list per position in the List vector. For example, a + * density value of 10 implies each position in the list + * vector has a list of 10 values. + * A density value of 0.1 implies out of 10 positions in + * the list vector, 1 position has a list of size 1 and + * remaining positions are null (no lists) or empty lists. + * This helps in tightly controlling the memory we provision + * for inner data vector. + */ + public void setInitialCapacity(int numRecords, double density) { + offsetAllocationSizeInBytes = (numRecords + 1) * OFFSET_WIDTH; + final int innerValueCapacity = (int)(numRecords * density); + if (innerValueCapacity < 1) { + throw new IllegalArgumentException("With the provided density and value count, potential value capacity for the data vector is 0"); + } + vector.setInitialCapacity(innerValueCapacity); + } + @Override public int getValueCapacity() { final int offsetValueCapacity = Math.max(getOffsetBufferValueCapacity() - 1, 0); diff --git a/java/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java b/java/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java index 8aeeb7e5a2886..b472dae069431 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/complex/ListVector.java @@ -31,12 +31,7 @@ import org.apache.arrow.memory.BaseAllocator; import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.memory.OutOfMemoryException; -import org.apache.arrow.vector.AddOrGetResult; -import org.apache.arrow.vector.BufferBacked; -import org.apache.arrow.vector.FieldVector; -import org.apache.arrow.vector.ValueVector; -import org.apache.arrow.vector.ZeroVector; -import org.apache.arrow.vector.BitVectorHelper; +import org.apache.arrow.vector.*; import org.apache.arrow.vector.complex.impl.ComplexCopier; import org.apache.arrow.vector.complex.impl.UnionListReader; import org.apache.arrow.vector.complex.impl.UnionListWriter; @@ -102,6 +97,54 @@ public void initializeChildrenFromFields(List children) { addOrGetVector.getVector().initializeChildrenFromFields(field.getChildren()); } + @Override + public void setInitialCapacity(int numRecords) { + validityAllocationSizeInBytes = 
getValidityBufferSizeFromCount(numRecords); + super.setInitialCapacity(numRecords); + } + + /** + * Specialized version of setInitialCapacity() for ListVector. This is + * used by some callers when they want to explicitly control and be + * conservative about memory allocated for inner data vector. This is + * very useful when we are working with memory constraints for a query + * and have a fixed amount of memory reserved for the record batch. In + * such cases, we are likely to face OOM or related problems when + * we reserve memory for a record batch with value count x and + * do setInitialCapacity(x) such that each vector allocates only + * what is necessary and not the default amount but the multiplier + * forces the memory requirement to go beyond what was needed. + * + * @param numRecords value count + * @param density density of ListVector. Density is the average size of + * list per position in the List vector. For example, a + * density value of 10 implies each position in the list + * vector has a list of 10 values. + * A density value of 0.1 implies out of 10 positions in + * the list vector, 1 position has a list of size 1 and + * remaining positions are null (no lists) or empty lists. + * This helps in tightly controlling the memory we provision + * for inner data vector. + */ + public void setInitialCapacity(int numRecords, double density) { + validityAllocationSizeInBytes = getValidityBufferSizeFromCount(numRecords); + super.setInitialCapacity(numRecords, density); + } + + /** + * Get the density of this ListVector + * @return density + */ + public double getDensity() { + if (valueCount == 0) { + return 0.0D; + } + final int startOffset = offsetBuffer.getInt(0); + final int endOffset = offsetBuffer.getInt(valueCount * OFFSET_WIDTH); + final double totalListSize = endOffset - startOffset; + return totalListSize/valueCount; + } + @Override public List getChildrenFromFields() { return singletonList(getDataVector()); @@ -623,7 +666,7 @@ public int getNullCount() { */ @Override public int getValueCapacity() { - return Math.min(getValidityBufferValueCapacity(), super.getValueCapacity()); + return getValidityAndOffsetValueCapacity(); } private int getValidityAndOffsetValueCapacity() { diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestListVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestListVector.java index e2023f4461879..d49a677f67922 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestListVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestListVector.java @@ -112,6 +112,9 @@ public void testCopyFrom() throws Exception { result = outVector.getObject(2); resultSet = (ArrayList) result; assertEquals(0, resultSet.size()); + + /* 3+0+0/3 */ + assertEquals(1.0D, inVector.getDensity(), 0); } } @@ -209,6 +212,9 @@ public void testSetLastSetUsage() throws Exception { listVector.setLastSet(3); listVector.setValueCount(10); + /* (3+2+3)/10 */ + assertEquals(0.8D, listVector.getDensity(), 0); + index = 0; offset = offsetBuffer.getInt(index * ListVector.OFFSET_WIDTH); assertEquals(Integer.toString(0), Integer.toString(offset)); @@ -709,6 +715,8 @@ public void testGetBufferAddress() throws Exception { listWriter.bigInt().writeBigInt(300); listWriter.endList(); + listVector.setValueCount(2); + /* check listVector contents */ Object result = listVector.getObject(0); ArrayList resultSet = (ArrayList) result; @@ -739,6 +747,9 @@ public void testGetBufferAddress() throws Exception { assertEquals(2, buffers.size()); 
assertEquals(bitAddress, buffers.get(0).memoryAddress()); assertEquals(offsetAddress, buffers.get(1).memoryAddress()); + + /* (3+2)/2 */ + assertEquals(2.5, listVector.getDensity(), 0); } } @@ -753,4 +764,61 @@ public void testConsistentChildName() throws Exception { assertTrue(emptyVectorStr.contains(ListVector.DATA_VECTOR_NAME)); } } + + @Test + public void testSetInitialCapacity() { + try (final ListVector vector = ListVector.empty("", allocator)) { + vector.addOrGetVector(FieldType.nullable(MinorType.INT.getType())); + + /** + * use the default multiplier of 5, + * 512 * 5 => 2560 * 4 => 10240 bytes => 16KB => 4096 value capacity. + */ + vector.setInitialCapacity(512); + vector.allocateNew(); + assertEquals(512, vector.getValueCapacity()); + assertEquals(4096, vector.getDataVector().getValueCapacity()); + + /* use density as 4 */ + vector.setInitialCapacity(512, 4); + vector.allocateNew(); + assertEquals(512, vector.getValueCapacity()); + assertEquals(512*4, vector.getDataVector().getValueCapacity()); + + /** + * inner value capacity we pass to data vector is 512 * 0.1 => 51 + * For an int vector this is 204 bytes of memory for data buffer + * and 7 bytes for validity buffer. + * and with power of 2 allocation, we allocate 256 bytes and 8 bytes + * for the data buffer and validity buffer of the inner vector. Thus + * value capacity of inner vector is 64 + */ + vector.setInitialCapacity(512, 0.1); + vector.allocateNew(); + assertEquals(512, vector.getValueCapacity()); + assertEquals(64, vector.getDataVector().getValueCapacity()); + + /** + * inner value capacity we pass to data vector is 512 * 0.01 => 5 + * For an int vector this is 20 bytes of memory for data buffer + * and 1 byte for validity buffer. + * and with power of 2 allocation, we allocate 32 bytes and 1 bytes + * for the data buffer and validity buffer of the inner vector. 
Thus + * value capacity of inner vector is 8 + */ + vector.setInitialCapacity(512, 0.01); + vector.allocateNew(); + assertEquals(512, vector.getValueCapacity()); + assertEquals(8, vector.getDataVector().getValueCapacity()); + + boolean error = false; + try { + vector.setInitialCapacity(5, 0.1); + } catch (IllegalArgumentException e) { + error = true; + } finally { + assertTrue(error); + } + } + } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestValueVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestValueVector.java index 601b2062ff698..992bb6264a1cf 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestValueVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestValueVector.java @@ -1908,4 +1908,40 @@ public static void setBytes(int index, byte[] bytes, VarCharVector vector) { vector.offsetBuffer.setInt((index + 1) * vector.OFFSET_WIDTH, currentOffset + bytes.length); vector.valueBuffer.setBytes(currentOffset, bytes, 0, bytes.length); } + + @Test /* VarCharVector */ + public void testSetInitialCapacity() { + try (final VarCharVector vector = new VarCharVector(EMPTY_SCHEMA_PATH, allocator)) { + + /* use the default 8 data bytes on average per element */ + vector.setInitialCapacity(4096); + vector.allocateNew(); + assertEquals(4096, vector.getValueCapacity()); + assertEquals(4096 * 8, vector.getDataBuffer().capacity()); + + vector.setInitialCapacity(4096, 1); + vector.allocateNew(); + assertEquals(4096, vector.getValueCapacity()); + assertEquals(4096, vector.getDataBuffer().capacity()); + + vector.setInitialCapacity(4096, 0.1); + vector.allocateNew(); + assertEquals(4096, vector.getValueCapacity()); + assertEquals(512, vector.getDataBuffer().capacity()); + + vector.setInitialCapacity(4096, 0.01); + vector.allocateNew(); + assertEquals(4096, vector.getValueCapacity()); + assertEquals(64, vector.getDataBuffer().capacity()); + + boolean error = false; + try { + vector.setInitialCapacity(5, 0.1); + } catch (IllegalArgumentException e) { + error = true; + } finally { + assertTrue(error); + } + } + } } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestVectorReAlloc.java b/java/vector/src/test/java/org/apache/arrow/vector/TestVectorReAlloc.java index f8edf8904c53e..ca039c52f9715 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestVectorReAlloc.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestVectorReAlloc.java @@ -104,7 +104,7 @@ public void testListType() { vector.setInitialCapacity(512); vector.allocateNew(); - assertEquals(1023, vector.getValueCapacity()); + assertEquals(512, vector.getValueCapacity()); try { vector.getInnerValueCountAt(2014); @@ -114,7 +114,7 @@ public void testListType() { } vector.reAlloc(); - assertEquals(2047, vector.getValueCapacity()); // note: size - 1 + assertEquals(1024, vector.getValueCapacity()); assertEquals(0, vector.getOffsetBuffer().getInt(2014 * ListVector.OFFSET_WIDTH)); } } diff --git a/python/README-benchmarks.md b/python/README-benchmarks.md new file mode 100644 index 0000000000000..6389665b075d9 --- /dev/null +++ b/python/README-benchmarks.md @@ -0,0 +1,54 @@ + + +# Benchmarks + +The `pyarrow` package comes with a suite of benchmarks meant to +run with [ASV](https://asv.readthedocs.io). You'll need to install +the `asv` package first (`pip install asv`). + +## Running with your local tree + +When developing, the simplest and fastest way to run the benchmark suite +against your local changes is to use the `asv dev` command. 
This will +use your current Python interpreter and environment. + +## Running with arbitrary revisions + +ASV allows to store results and generate graphs of the benchmarks over +the project's evolution. Doing this requires a bit more massaging +currently. + +First you have to install our ASV fork: + +```shell +pip install git+https://github.com/pitrou/asv.git@issue-547-specify-project-subdir +``` + + + +Then you need to set up a few environment variables: + +```shell +export SETUPTOOLS_SCM_PRETEND_VERSION=0.0.1 +export PYARROW_BUNDLE_ARROW_CPP=1 +``` + +Now you should be ready to run `asv run` or whatever other command +suits your needs. diff --git a/python/asv.conf.json b/python/asv.conf.json index 2a1dd42aba136..150153c8020f9 100644 --- a/python/asv.conf.json +++ b/python/asv.conf.json @@ -28,12 +28,17 @@ // The URL or local path of the source code repository for the // project being benchmarked - "repo": "https://github.com/apache/arrow/", + "repo": "..", + + // The Python project's subdirectory in your repo. If missing or + // the empty string, the project is assumed to be located at the root + // of the repository. + "repo_subdir": "python", // List of branches to benchmark. If not provided, defaults to "master" - // (for git) or "tip" (for mercurial). + // (for git) or "default" (for mercurial). // "branches": ["master"], // for git - // "branches": ["tip"], // for mercurial + // "branches": ["default"], // for mercurial // The DVCS being used. If not set, it will be automatically // determined from "repo" by looking at the protocol in the URL @@ -46,22 +51,72 @@ // If missing or the empty string, the tool will be automatically // determined by looking for tools on the PATH environment // variable. - "environment_type": "virtualenv", + "environment_type": "conda", + "conda_channels": ["conda-forge", "defaults"], // the base URL to show a commit for the project. "show_commit_url": "https://github.com/apache/arrow/commit/", // The Pythons you'd like to test against. If not provided, defaults // to the current version of Python used to run `asv`. - // "pythons": ["2.7", "3.3"], + "pythons": ["3.6"], // The matrix of dependencies to test. Each key is the name of a // package (in PyPI) and the values are version numbers. An empty - // list indicates to just test against the default (latest) - // version. + // list or empty string indicates to just test against the default + // (latest) version. null indicates that the package is to not be + // installed. If the package to be tested is only available from + // PyPi, and the 'environment_type' is conda, then you can preface + // the package name by 'pip+', and the package will be installed via + // pip (with all the conda available packages installed first, + // followed by the pip installed packages). + // // "matrix": { - // "numpy": ["1.6", "1.7"] + // "numpy": ["1.6", "1.7"], + // "six": ["", null], // test with and without six installed + // "pip+emcee": [""], // emcee is only available for install with pip. // }, + "matrix": { + "boost-cpp": [], + "cmake": [], + "cython": [], + "numpy": ["1.14"], + "pandas": ["0.22"], + "pip+setuptools_scm": [], + }, + + // Combinations of libraries/python versions can be excluded/included + // from the set to test. Each entry is a dictionary containing additional + // key-value pairs to include/exclude. + // + // An exclude entry excludes entries where all values match. The + // values are regexps that should match the whole string. + // + // An include entry adds an environment. 
Only the packages listed + // are installed. The 'python' key is required. The exclude rules + // do not apply to includes. + // + // In addition to package names, the following keys are available: + // + // - python + // Python version, as in the *pythons* variable above. + // - environment_type + // Environment type, as above. + // - sys_platform + // Platform, as in sys.platform. Possible values for the common + // cases: 'linux2', 'win32', 'cygwin', 'darwin'. + // + // "exclude": [ + // {"python": "3.2", "sys_platform": "win32"}, // skip py3.2 on windows + // {"environment_type": "conda", "six": null}, // don't run without six on conda + // ], + // + // "include": [ + // // additional env for python2.7 + // {"python": "2.7", "numpy": "1.8"}, + // // additional env if run on windows+conda + // {"platform": "win32", "environment_type": "conda", "python": "2.7", "libpython": ""}, + // ], // The directory (relative to the current directory) that benchmarks are // stored in. If not provided, defaults to "benchmarks" @@ -71,7 +126,6 @@ // environments in. If not provided, defaults to "env" "env_dir": ".asv/env", - // The directory (relative to the current directory) that raw benchmark // results are stored in. If not provided, defaults to "results". "results_dir": ".asv/results", @@ -86,5 +140,29 @@ // `asv` will cache wheels of the recent builds in each // environment, making them faster to install next time. This is // number of builds to keep, per environment. - // "wheel_cache_size": 0 + // "wheel_cache_size": 0, + + // The commits after which the regression search in `asv publish` + // should start looking for regressions. Dictionary whose keys are + // regexps matching to benchmark names, and values corresponding to + // the commit (exclusive) after which to start looking for + // regressions. The default is to start from the first commit + // with results. If the commit is `null`, regression detection is + // skipped for the matching benchmark. + // + // "regressions_first_commits": { + // "some_benchmark": "352cdf", // Consider regressions only after this commit + // "another_benchmark": null, // Skip regression detection altogether + // } + + // The thresholds for relative change in results, after which `asv + // publish` starts reporting regressions. Dictionary of the same + // form as in ``regressions_first_commits``, with values + // indicating the thresholds. If multiple entries match, the + // maximum is taken. If no entry matches, the default is 5%. + // + // "regressions_thresholds": { + // "some_benchmark": 0.01, // Threshold of 1% + // "another_benchmark": 0.5, // Threshold of 50% + // } } diff --git a/c_glib/doc/reference/xml/meson.build b/python/benchmarks/array_ops.py similarity index 55% rename from c_glib/doc/reference/xml/meson.build rename to python/benchmarks/array_ops.py index 5b65042764fee..70ee7f1e1fcfc 100644 --- a/c_glib/doc/reference/xml/meson.build +++ b/python/benchmarks/array_ops.py @@ -1,5 +1,3 @@ -# -*- indent-tabs-mode: nil -*- -# # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -17,15 +15,21 @@ # specific language governing permissions and limitations # under the License. 
-entities_conf = configuration_data() -entities_conf.set('package', meson.project_name()) -entities_conf.set('package_bugreport', - 'https://issues.apache.org/jira/browse/ARROW') -entities_conf.set('package_name', meson.project_name()) -entities_conf.set('package_string', - ' '.join([meson.project_name(), version])) -entities_conf.set('package_url', 'https://arrow.apache.org/') -entities_conf.set('package_version', version) -configure_file(input: 'gtkdocentities.ent.in', - output: 'gtkdocentities.ent', - configuration: entities_conf) +import numpy as np +import pyarrow as pa + + +class ScalarAccess(object): + n = 10 ** 5 + + def setUp(self): + self._array = pa.array(list(range(self.n)), type=pa.int64()) + self._array_items = list(self._array) + + def time_getitem(self): + for i in range(self.n): + self._array[i] + + def time_as_py(self): + for item in self._array_items: + item.as_py() diff --git a/python/benchmarks/common.py b/python/benchmarks/common.py new file mode 100644 index 0000000000000..7dd42fde5abe1 --- /dev/null +++ b/python/benchmarks/common.py @@ -0,0 +1,84 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import codecs +import os +import sys +import unicodedata + +import numpy as np + + +def _multiplicate_sequence(base, target_size): + q, r = divmod(target_size, len(base)) + return [base] * q + [base[:r]] + + +def get_random_bytes(n): + rnd = np.random.RandomState(42) + # Computing a huge random bytestring can be costly, so we get at most + # 100KB and duplicate the result as needed + base_size = 100003 + q, r = divmod(n, base_size) + if q == 0: + result = rnd.bytes(r) + else: + base = rnd.bytes(base_size) + result = b''.join(_multiplicate_sequence(base, n)) + assert len(result) == n + return result + + +def get_random_ascii(n): + arr = np.frombuffer(get_random_bytes(n), dtype=np.int8) & 0x7f + result, _ = codecs.ascii_decode(arr) + assert isinstance(result, str) + assert len(result) == n + return result + + +def _random_unicode_letters(n): + """ + Generate a string of random unicode letters (slow). + """ + def _get_more_candidates(): + return rnd.randint(0, sys.maxunicode, size=n).tolist() + + rnd = np.random.RandomState(42) + out = [] + candidates = [] + + while len(out) < n: + if not candidates: + candidates = _get_more_candidates() + ch = chr(candidates.pop()) + # XXX Do we actually care that the code points are valid? 
+ if unicodedata.category(ch)[0] == 'L': + out.append(ch) + return out + + +_1024_random_unicode_letters = _random_unicode_letters(1024) + + +def get_random_unicode(n): + indices = np.frombuffer(get_random_bytes(n * 2), dtype=np.int16) & 1023 + unicode_arr = np.array(_1024_random_unicode_letters)[indices] + + result = ''.join(unicode_arr.tolist()) + assert len(result) == n, (len(result), len(unicode_arr)) + return result diff --git a/python/benchmarks/convert_builtins.py b/python/benchmarks/convert_builtins.py new file mode 100644 index 0000000000000..92b2b850f2a0a --- /dev/null +++ b/python/benchmarks/convert_builtins.py @@ -0,0 +1,295 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from functools import partial +import itertools + +import numpy as np +import pyarrow as pa + +from . import common + + +DEFAULT_NONE_PROB = 0.3 + + +# TODO: +# - test dates and times +# - test decimals + +class BuiltinsGenerator(object): + + def __init__(self, seed=42): + self.rnd = np.random.RandomState(seed) + + def sprinkle_nones(self, lst, prob): + """ + Sprinkle None entries in list *lst* with likelihood *prob*. + """ + for i, p in enumerate(self.rnd.random_sample(size=len(lst))): + if p < prob: + lst[i] = None + + def generate_int_list(self, n, none_prob=DEFAULT_NONE_PROB): + """ + Generate a list of Python ints with *none_prob* probability of + an entry being None. + """ + data = list(range(n)) + self.sprinkle_nones(data, none_prob) + return data + + def generate_float_list(self, n, none_prob=DEFAULT_NONE_PROB): + """ + Generate a list of Python floats with *none_prob* probability of + an entry being None. + """ + # Make sure we get Python floats, not np.float64 + data = list(map(float, self.rnd.uniform(0.0, 1.0, n))) + assert len(data) == n + self.sprinkle_nones(data, none_prob) + return data + + def generate_bool_list(self, n, none_prob=DEFAULT_NONE_PROB): + """ + Generate a list of Python bools with *none_prob* probability of + an entry being None. + """ + # Make sure we get Python bools, not np.bool_ + data = [bool(x >= 0.5) for x in self.rnd.uniform(0.0, 1.0, n)] + assert len(data) == n + self.sprinkle_nones(data, none_prob) + return data + + def _generate_varying_sequences(self, random_factory, n, min_size, max_size, none_prob): + """ + Generate a list of *n* sequences of varying size between *min_size* + and *max_size*, with *none_prob* probability of an entry being None. 
+ The base material for each sequence is obtained by calling + `random_factory()` + """ + base_size = 10000 + base = random_factory(base_size + max_size) + data = [] + for i in range(n): + off = self.rnd.randint(base_size) + if min_size == max_size: + size = min_size + else: + size = self.rnd.randint(min_size, max_size + 1) + data.append(base[off:off + size]) + self.sprinkle_nones(data, none_prob) + assert len(data) == n + return data + + def generate_fixed_binary_list(self, n, size, none_prob=DEFAULT_NONE_PROB): + """ + Generate a list of bytestrings with a fixed *size*. + """ + return self._generate_varying_sequences(common.get_random_bytes, n, + size, size, none_prob) + + + def generate_varying_binary_list(self, n, min_size, max_size, + none_prob=DEFAULT_NONE_PROB): + """ + Generate a list of bytestrings with a random size between + *min_size* and *max_size*. + """ + return self._generate_varying_sequences(common.get_random_bytes, n, + min_size, max_size, none_prob) + + + def generate_ascii_string_list(self, n, min_size, max_size, + none_prob=DEFAULT_NONE_PROB): + """ + Generate a list of ASCII strings with a random size between + *min_size* and *max_size*. + """ + return self._generate_varying_sequences(common.get_random_ascii, n, + min_size, max_size, none_prob) + + + def generate_unicode_string_list(self, n, min_size, max_size, + none_prob=DEFAULT_NONE_PROB): + """ + Generate a list of unicode strings with a random size between + *min_size* and *max_size*. + """ + return self._generate_varying_sequences(common.get_random_unicode, n, + min_size, max_size, none_prob) + + + def generate_int_list_list(self, n, min_size, max_size, + none_prob=DEFAULT_NONE_PROB): + """ + Generate a list of lists of Python ints with a random size between + *min_size* and *max_size*. + """ + return self._generate_varying_sequences( + partial(self.generate_int_list, none_prob=none_prob), + n, min_size, max_size, none_prob) + + + def generate_dict_list(self, n, none_prob=DEFAULT_NONE_PROB): + """ + Generate a list of dicts with a random size between *min_size* and + *max_size*. + Each dict has the form `{'u': int value, 'v': float value, 'w': bool value}` + """ + ints = self.generate_int_list(n, none_prob=none_prob) + floats = self.generate_float_list(n, none_prob=none_prob) + bools = self.generate_bool_list(n, none_prob=none_prob) + dicts = [] + # Keep half the Nones, omit the other half + keep_nones = itertools.cycle([True, False]) + for u, v, w in zip(ints, floats, bools): + d = {} + if u is not None or next(keep_nones): + d['u'] = u + if v is not None or next(keep_nones): + d['v'] = v + if w is not None or next(keep_nones): + d['w'] = w + dicts.append(d) + self.sprinkle_nones(dicts, none_prob) + assert len(dicts) == n + return dicts + + def get_type_and_builtins(self, n, type_name): + """ + Return a `(arrow type, list)` tuple where the arrow type + corresponds to the given logical *type_name*, and the list + is a list of *n* random-generated Python objects compatible + with the arrow type. 
+ """ + size = None + + if type_name in ('bool', 'ascii', 'unicode', 'int64 list', 'struct'): + kind = type_name + elif type_name.startswith(('int', 'uint')): + kind = 'int' + elif type_name.startswith('float'): + kind = 'float' + elif type_name == 'binary': + kind = 'varying binary' + elif type_name.startswith('binary'): + kind = 'fixed binary' + size = int(type_name[6:]) + assert size > 0 + else: + raise ValueError("unrecognized type %r" % (type_name,)) + + if kind in ('int', 'float'): + ty = getattr(pa, type_name)() + elif kind == 'bool': + ty = pa.bool_() + elif kind == 'fixed binary': + ty = pa.binary(size) + elif kind == 'varying binary': + ty = pa.binary() + elif kind in ('ascii', 'unicode'): + ty = pa.string() + elif kind == 'int64 list': + ty = pa.list_(pa.int64()) + elif kind == 'struct': + ty = pa.struct([pa.field('u', pa.int64()), + pa.field('v', pa.float64()), + pa.field('w', pa.bool_())]) + + factories = { + 'int': self.generate_int_list, + 'float': self.generate_float_list, + 'bool': self.generate_bool_list, + 'fixed binary': partial(self.generate_fixed_binary_list, + size=size), + 'varying binary': partial(self.generate_varying_binary_list, + min_size=3, max_size=40), + 'ascii': partial(self.generate_ascii_string_list, + min_size=3, max_size=40), + 'unicode': partial(self.generate_unicode_string_list, + min_size=3, max_size=40), + 'int64 list': partial(self.generate_int_list_list, + min_size=0, max_size=20), + 'struct': self.generate_dict_list, + } + data = factories[kind](n) + return ty, data + + +class ConvertPyListToArray(object): + """ + Benchmark pa.array(list of values, type=...) + """ + size = 10 ** 5 + types = ('int32', 'uint32', 'int64', 'uint64', + 'float32', 'float64', 'bool', + 'binary', 'binary10', 'ascii', 'unicode', + 'int64 list', 'struct') + + param_names = ['type'] + params = [types] + + def setup(self, type_name): + gen = BuiltinsGenerator() + self.ty, self.data = gen.get_type_and_builtins(self.size, type_name) + + def time_convert(self, *args): + pa.array(self.data, type=self.ty) + + +class InferPyListToArray(object): + """ + Benchmark pa.array(list of values) with type inference + """ + size = 10 ** 5 + types = ('int64', 'float64', 'bool', 'binary', 'ascii', 'unicode', + 'int64 list') + # TODO add 'struct' when supported + + param_names = ['type'] + params = [types] + + def setup(self, type_name): + gen = BuiltinsGenerator() + self.ty, self.data = gen.get_type_and_builtins(self.size, type_name) + + def time_infer(self, *args): + arr = pa.array(self.data) + assert arr.type == self.ty + + +class ConvertArrayToPyList(object): + """ + Benchmark pa.array.to_pylist() + """ + size = 10 ** 5 + types = ('int32', 'uint32', 'int64', 'uint64', + 'float32', 'float64', 'bool', + 'binary', 'binary10', 'ascii', 'unicode', + 'int64 list', 'struct') + + param_names = ['type'] + params = [types] + + def setup(self, type_name): + gen = BuiltinsGenerator() + self.ty, self.data = gen.get_type_and_builtins(self.size, type_name) + self.arr = pa.array(self.data, type=self.ty) + + def time_convert(self, *args): + self.arr.to_pylist() diff --git a/python/benchmarks/array.py b/python/benchmarks/convert_pandas.py similarity index 59% rename from python/benchmarks/array.py rename to python/benchmarks/convert_pandas.py index e22c0f7fc9e70..c4a7a59cb77dc 100644 --- a/python/benchmarks/array.py +++ b/python/benchmarks/convert_pandas.py @@ -17,21 +17,7 @@ import numpy as np import pandas as pd -import pyarrow as A - - -class PyListConversions(object): - param_names = ('size',) - params 
= (1, 10 ** 5, 10 ** 6, 10 ** 7) - - def setup(self, n): - self.data = list(range(n)) - - def time_from_pylist(self, n): - A.from_pylist(self.data) - - def peakmem_from_pylist(self, n): - A.from_pylist(self.data) +import pyarrow as pa class PandasConversionsBase(object): @@ -46,37 +32,19 @@ def setup(self, n, dtype): class PandasConversionsToArrow(PandasConversionsBase): param_names = ('size', 'dtype') - params = ((1, 10 ** 5, 10 ** 6, 10 ** 7), ('int64', 'float64', 'float64_nans', 'str')) + params = ((10, 10 ** 6), ('int64', 'float64', 'float64_nans', 'str')) def time_from_series(self, n, dtype): - A.Table.from_pandas(self.data) - - def peakmem_from_series(self, n, dtype): - A.Table.from_pandas(self.data) + pa.Table.from_pandas(self.data) class PandasConversionsFromArrow(PandasConversionsBase): param_names = ('size', 'dtype') - params = ((1, 10 ** 5, 10 ** 6, 10 ** 7), ('int64', 'float64', 'float64_nans', 'str')) + params = ((10, 10 ** 6), ('int64', 'float64', 'float64_nans', 'str')) def setup(self, n, dtype): super(PandasConversionsFromArrow, self).setup(n, dtype) - self.arrow_data = A.Table.from_pandas(self.data) + self.arrow_data = pa.Table.from_pandas(self.data) def time_to_series(self, n, dtype): self.arrow_data.to_pandas() - - def peakmem_to_series(self, n, dtype): - self.arrow_data.to_pandas() - - -class ScalarAccess(object): - param_names = ('size',) - params = (1, 10 ** 5, 10 ** 6, 10 ** 7) - - def setUp(self, n): - self._array = A.from_pylist(list(range(n))) - - def time_as_py(self, n): - for i in range(n): - self._array[i].as_py() diff --git a/python/doc/source/development.rst b/python/doc/source/development.rst index 01844fa18d133..af93d8d1a52c4 100644 --- a/python/doc/source/development.rst +++ b/python/doc/source/development.rst @@ -331,3 +331,76 @@ Getting ``python-test.exe`` to run is a bit tricky because your set PYTHONPATH=%CONDA_ENV%\Lib;%CONDA_ENV%\Lib\site-packages;%CONDA_ENV%\python35.zip;%CONDA_ENV%\DLLs;%CONDA_ENV% Now ``python-test.exe`` or simply ``ctest`` (to run all tests) should work. + +Nightly Builds of `arrow-cpp`, `parquet-cpp`, and `pyarrow` for Linux +--------------------------------------------------------------------- + +Nightly builds of Linux conda packages for ``arrow-cpp``, ``parquet-cpp``, and +``pyarrow`` can be automated using an open source tool called `scourge +`_. + +``scourge`` is new, so please report any feature requests or bugs to the +`scourge issue tracker `_. + +To get scourge you need to clone the source and install it in development mode. + +To setup your own nightly builds: + +#. Clone and install scourge +#. Create a script that calls scourge +#. Run that script as a cronjob once per day + +First, clone and install scourge (you also need to `install docker +`): + + +.. code:: sh + + git clone https://github.com/cpcloud/scourge + cd scourge + python setup.py develop + which scourge + +Second, create a shell script that calls scourge: + +.. 
code:: sh + + function build() { + # make sure we got a working directory + workingdir="${1}" + [ -z "${workingdir}" ] && echo "Must provide a working directory" && exit 1 + scourge="/path/to/scourge" + + # get the hash of master for building parquet + PARQUET_ARROW_VERSION="$("${scourge}" sha apache/arrow master)" + + # setup the build for each package + "${scourge}" init arrow-cpp@master parquet-cpp@master pyarrow@master + + # build the packages with some constraints (the -c arguments) + # -e sets environment variables on a per package basis + "${scourge}" build \ + -e parquet-cpp:PARQUET_ARROW_VERSION="${PARQUET_ARROW_VERSION}" \ + -c "python >=2.7,<3|>=3.5" \ + -c "numpy >= 1.11" \ + -c "r-base >=3.3.2" + } + + workingdir="$(date +'%Y%m%d_%H_%M_%S')" + mkdir -p "${workingdir}" + build "${workingdir}" > "${workingdir}"/scourge.log 2>&1 + +Third, run that script as a cronjob once per day: + +.. code:: sh + + crontab -e + +then in the scratch file that's opened: + +.. code:: sh + + @daily /path/to/the/above/script.sh + +The build artifacts (conda packages) will be located in +``${workingdir}/artifacts/linux-64``. diff --git a/python/manylinux1/Dockerfile-x86_64 b/python/manylinux1/Dockerfile-x86_64 index 919a32be715b0..9c00e7ea256c9 100644 --- a/python/manylinux1/Dockerfile-x86_64 +++ b/python/manylinux1/Dockerfile-x86_64 @@ -14,7 +14,7 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. -FROM quay.io/xhochy/arrow_manylinux1_x86_64_base:latest +FROM quay.io/xhochy/arrow_manylinux1_x86_64_base:ARROW-2048 ADD arrow /arrow WORKDIR /arrow/cpp diff --git a/python/manylinux1/Dockerfile-x86_64_base b/python/manylinux1/Dockerfile-x86_64_base index 0160aa4eea509..ec7893080f65b 100644 --- a/python/manylinux1/Dockerfile-x86_64_base +++ b/python/manylinux1/Dockerfile-x86_64_base @@ -42,6 +42,9 @@ ADD scripts/build_flatbuffers.sh / RUN /build_flatbuffers.sh ENV FLATBUFFERS_HOME /usr +ADD scripts/build_bison.sh / +RUN /build_bison.sh + ADD scripts/build_thrift.sh / RUN /build_thrift.sh ENV THRIFT_HOME /usr diff --git a/c_glib/doc/reference/xml/Makefile.am b/python/manylinux1/scripts/build_bison.sh old mode 100644 new mode 100755 similarity index 79% rename from c_glib/doc/reference/xml/Makefile.am rename to python/manylinux1/scripts/build_bison.sh index 833cfddc69078..29cc0be6adf6c --- a/c_glib/doc/reference/xml/Makefile.am +++ b/python/manylinux1/scripts/build_bison.sh @@ -1,3 +1,4 @@ +#!/bin/bash -ex # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information @@ -15,6 +16,11 @@ # specific language governing permissions and limitations # under the License. -EXTRA_DIST = \ - gtkdocentities.ent.in \ - meson.build +wget http://ftp.gnu.org/gnu/bison/bison-3.0.4.tar.gz +tar xf bison-3.0.4.tar.gz +pushd bison-3.0.4 +./configure --prefix=/usr +make -j4 +make install +popd +rm -rf bison-3.0.4 bison-3.0.4.tar.gz diff --git a/python/manylinux1/scripts/build_thrift.sh b/python/manylinux1/scripts/build_thrift.sh index 28aa75b7413de..aaec4ad6bad41 100755 --- a/python/manylinux1/scripts/build_thrift.sh +++ b/python/manylinux1/scripts/build_thrift.sh @@ -16,7 +16,7 @@ # specific language governing permissions and limitations # under the License. 
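+# Version of Apache Thrift that gets built into the manylinux1 base image.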
-export THRIFT_VERSION=0.10.0 +export THRIFT_VERSION=0.11.0 wget http://archive.apache.org/dist/thrift/${THRIFT_VERSION}/thrift-${THRIFT_VERSION}.tar.gz tar xf thrift-${THRIFT_VERSION}.tar.gz pushd thrift-${THRIFT_VERSION} diff --git a/python/pyarrow/__init__.py b/python/pyarrow/__init__.py index a245fe6796023..8b3cba92414f8 100644 --- a/python/pyarrow/__init__.py +++ b/python/pyarrow/__init__.py @@ -124,9 +124,10 @@ localfs = LocalFileSystem.get_instance() -from pyarrow.serialization import (_default_serialization_context, +from pyarrow.serialization import (default_serialization_context, pandas_serialization_context, - register_default_serialization_handlers) + register_default_serialization_handlers, + register_torch_serialization_handlers) import pyarrow.types as types diff --git a/python/pyarrow/_orc.pxd b/python/pyarrow/_orc.pxd index 411691510423c..c07a19442b577 100644 --- a/python/pyarrow/_orc.pxd +++ b/python/pyarrow/_orc.pxd @@ -29,9 +29,10 @@ from pyarrow.includes.libarrow cimport (CArray, CSchema, CStatus, TimeUnit) -cdef extern from "arrow/adapters/orc/adapter.h" namespace "arrow::adapters::orc" nogil: - cdef cppclass ORCFileReader: +cdef extern from "arrow/adapters/orc/adapter.h" \ + namespace "arrow::adapters::orc" nogil: + cdef cppclass ORCFileReader: @staticmethod CStatus Open(const shared_ptr[RandomAccessFile]& file, CMemoryPool* pool, @@ -40,7 +41,8 @@ cdef extern from "arrow/adapters/orc/adapter.h" namespace "arrow::adapters::orc" CStatus ReadSchema(shared_ptr[CSchema]* out) CStatus ReadStripe(int64_t stripe, shared_ptr[CRecordBatch]* out) - CStatus ReadStripe(int64_t stripe, std_vector[int], shared_ptr[CRecordBatch]* out) + CStatus ReadStripe(int64_t stripe, std_vector[int], + shared_ptr[CRecordBatch]* out) CStatus Read(shared_ptr[CTable]* out) CStatus Read(std_vector[int], shared_ptr[CTable]* out) diff --git a/python/pyarrow/_orc.pyx b/python/pyarrow/_orc.pyx index 7ff4bac6dc95f..cf04f48a32319 100644 --- a/python/pyarrow/_orc.pyx +++ b/python/pyarrow/_orc.pyx @@ -50,7 +50,7 @@ cdef class ORCReader: get_reader(source, &rd_handle) with nogil: check_status(ORCFileReader.Open(rd_handle, self.allocator, - &self.reader)) + &self.reader)) def schema(self): """ @@ -69,10 +69,10 @@ cdef class ORCReader: return pyarrow_wrap_schema(sp_arrow_schema) def nrows(self): - return deref(self.reader).NumberOfRows(); + return deref(self.reader).NumberOfRows() def nstripes(self): - return deref(self.reader).NumberOfStripes(); + return deref(self.reader).NumberOfStripes() def read_stripe(self, n, include_indices=None): cdef: @@ -85,11 +85,13 @@ cdef class ORCReader: if include_indices is None: with nogil: - check_status(deref(self.reader).ReadStripe(stripe, &sp_record_batch)) + (check_status(deref(self.reader) + .ReadStripe(stripe, &sp_record_batch))) else: indices = include_indices with nogil: - check_status(deref(self.reader).ReadStripe(stripe, indices, &sp_record_batch)) + (check_status(deref(self.reader) + .ReadStripe(stripe, indices, &sp_record_batch))) batch = RecordBatch() batch.init(sp_record_batch) diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi index cca9425881b00..caeefd2ff4f6a 100644 --- a/python/pyarrow/array.pxi +++ b/python/pyarrow/array.pxi @@ -21,14 +21,21 @@ cdef _sequence_to_array(object sequence, object size, DataType type, cdef shared_ptr[CArray] out cdef int64_t c_size if type is None: - with nogil: - check_status(ConvertPySequence(sequence, pool, &out)) + if size is None: + with nogil: + check_status(ConvertPySequence(sequence, pool, &out)) + else: + 
c_size = size + with nogil: + check_status( + ConvertPySequence(sequence, c_size, pool, &out) + ) else: if size is None: with nogil: check_status( ConvertPySequence( - sequence, pool, &out, type.sp_type + sequence, type.sp_type, pool, &out, ) ) else: @@ -36,7 +43,7 @@ cdef _sequence_to_array(object sequence, object size, DataType type, with nogil: check_status( ConvertPySequence( - sequence, pool, &out, type.sp_type, c_size + sequence, c_size, type.sp_type, pool, &out, ) ) diff --git a/python/pyarrow/hdfs.py b/python/pyarrow/hdfs.py index 3c9d04188a6ca..3f2014b65c097 100644 --- a/python/pyarrow/hdfs.py +++ b/python/pyarrow/hdfs.py @@ -36,6 +36,10 @@ def __init__(self, host="default", port=0, user=None, kerb_ticket=None, self._connect(host, port, user, kerb_ticket, driver) + def __reduce__(self): + return (HadoopFileSystem, (self.host, self.port, self.user, + self.kerb_ticket, self.driver)) + @implements(FileSystem.isdir) def isdir(self, path): return super(HadoopFileSystem, self).isdir(path) diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index 91bc96dc63f89..2e83f0701ce2e 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -852,13 +852,14 @@ cdef extern from "arrow/python/api.h" namespace "arrow::py" nogil: shared_ptr[CDataType] GetTimestampType(TimeUnit unit) CStatus ConvertPySequence(object obj, CMemoryPool* pool, shared_ptr[CArray]* out) - CStatus ConvertPySequence(object obj, CMemoryPool* pool, - shared_ptr[CArray]* out, - const shared_ptr[CDataType]& type) - CStatus ConvertPySequence(object obj, CMemoryPool* pool, - shared_ptr[CArray]* out, + CStatus ConvertPySequence(object obj, const shared_ptr[CDataType]& type, + CMemoryPool* pool, shared_ptr[CArray]* out) + CStatus ConvertPySequence(object obj, int64_t size, CMemoryPool* pool, + shared_ptr[CArray]* out) + CStatus ConvertPySequence(object obj, int64_t size, const shared_ptr[CDataType]& type, - int64_t size) + CMemoryPool* pool, + shared_ptr[CArray]* out) CStatus NumPyDtypeToArrow(object dtype, shared_ptr[CDataType]* type) diff --git a/python/pyarrow/io-hdfs.pxi b/python/pyarrow/io-hdfs.pxi index e653813235862..83b14b687830d 100644 --- a/python/pyarrow/io-hdfs.pxi +++ b/python/pyarrow/io-hdfs.pxi @@ -59,29 +59,41 @@ cdef class HadoopFileSystem: cdef readonly: bint is_open - - def __cinit__(self): - pass + str host + str user + str kerb_ticket + str driver + int port def _connect(self, host, port, user, kerb_ticket, driver): cdef HdfsConnectionConfig conf if host is not None: conf.host = tobytes(host) + self.host = host + conf.port = port + self.port = port + if user is not None: conf.user = tobytes(user) + self.user = user + if kerb_ticket is not None: conf.kerb_ticket = tobytes(kerb_ticket) + self.kerb_ticket = kerb_ticket if driver == 'libhdfs': with nogil: check_status(HaveLibHdfs()) conf.driver = HdfsDriver_LIBHDFS - else: + elif driver == 'libhdfs3': with nogil: check_status(HaveLibHdfs3()) conf.driver = HdfsDriver_LIBHDFS3 + else: + raise ValueError("unknown driver: %r" % driver) + self.driver = driver with nogil: check_status(CHadoopFileSystem.Connect(&conf, &self.client)) @@ -401,9 +413,7 @@ cdef class HadoopFileSystem: &wr_handle)) out.wr_file = wr_handle - - out.is_readable = False - out.is_writeable = 1 + out.is_writable = True else: with nogil: check_status(self.client.get() @@ -411,7 +421,6 @@ cdef class HadoopFileSystem: out.rd_file = rd_handle out.is_readable = True - out.is_writeable = 0 if c_buffer_size == 0: c_buffer_size = 2 ** 
16 @@ -419,7 +428,7 @@ cdef class HadoopFileSystem: out.mode = mode out.buffer_size = c_buffer_size out.parent = _HdfsFileNanny(self, out) - out.is_open = True + out.closed = False out.own_file = True return out diff --git a/python/pyarrow/io.pxi b/python/pyarrow/io.pxi index 5449872ff101f..bd508cf57ee8d 100644 --- a/python/pyarrow/io.pxi +++ b/python/pyarrow/io.pxi @@ -39,13 +39,14 @@ cdef extern from "Python.h": cdef class NativeFile: - def __cinit__(self): - self.is_open = False + self.closed = True self.own_file = False + self.is_readable = False + self.is_writable = False def __dealloc__(self): - if self.is_open and self.own_file: + if self.own_file and not self.closed: self.close() def __enter__(self): @@ -65,45 +66,63 @@ cdef class NativeFile: def __get__(self): # Emulate built-in file modes - if self.is_readable and self.is_writeable: + if self.is_readable and self.is_writable: return 'rb+' elif self.is_readable: return 'rb' - elif self.is_writeable: + elif self.is_writable: return 'wb' else: raise ValueError('File object is malformed, has no mode') + def readable(self): + self._assert_open() + return self.is_readable + + def writable(self): + self._assert_open() + return self.is_writable + + def seekable(self): + self._assert_open() + return self.is_readable + def close(self): - if self.is_open: + if not self.closed: with nogil: if self.is_readable: check_status(self.rd_file.get().Close()) else: check_status(self.wr_file.get().Close()) - self.is_open = False + self.closed = True + + def flush(self): + """Flush the buffer stream, if applicable. + + No-op to match the IOBase interface.""" + self._assert_open() cdef read_handle(self, shared_ptr[RandomAccessFile]* file): self._assert_readable() file[0] = self.rd_file cdef write_handle(self, shared_ptr[OutputStream]* file): - self._assert_writeable() + self._assert_writable() file[0] = self.wr_file + def _assert_open(self): + if self.closed: + raise ValueError("I/O operation on closed file") + def _assert_readable(self): + self._assert_open() if not self.is_readable: raise IOError("only valid on readonly files") - if not self.is_open: - raise IOError("file not open") - - def _assert_writeable(self): - if not self.is_writeable: - raise IOError("only valid on writeable files") - - if not self.is_open: - raise IOError("file not open") + def _assert_writable(self): + self._assert_open() + if not self.is_writable: + raise IOError("only valid on writable files") def size(self): """ @@ -120,6 +139,7 @@ cdef class NativeFile: Return current stream position """ cdef int64_t position + self._assert_open() with nogil: if self.is_readable: check_status(self.rd_file.get().Tell(&position)) @@ -174,7 +194,7 @@ cdef class NativeFile: Write byte from any object implementing buffer protocol (bytes, bytearray, ndarray, pyarrow.Buffer) """ - self._assert_writeable() + self._assert_writable() if isinstance(data, six.string_types): data = tobytes(data) @@ -223,6 +243,12 @@ cdef class NativeFile: return PyObject_to_object(obj) + def read1(self, nbytes=None): + """Read and return up to n bytes. 
+ + Alias for read, needed to match the IOBase interface.""" + return self.read(nbytes=None) + def read_buffer(self, nbytes=None): cdef: int64_t c_nbytes @@ -332,7 +358,7 @@ cdef class NativeFile: Pipe file-like object to file """ write_queue = Queue(50) - self._assert_writeable() + self._assert_writable() buffer_size = buffer_size or DEFAULT_BUFFER_SIZE @@ -389,16 +415,14 @@ cdef class PythonFile(NativeFile): if mode.startswith('w'): self.wr_file.reset(new PyOutputStream(handle)) - self.is_readable = 0 - self.is_writeable = 1 + self.is_writable = True elif mode.startswith('r'): self.rd_file.reset(new PyReadableFile(handle)) - self.is_readable = 1 - self.is_writeable = 0 + self.is_readable = True else: raise ValueError('Invalid file mode: {0}'.format(mode)) - self.is_open = True + self.closed = False cdef class MemoryMappedFile(NativeFile): @@ -408,11 +432,6 @@ cdef class MemoryMappedFile(NativeFile): cdef: object path - def __cinit__(self): - self.is_open = False - self.is_readable = 0 - self.is_writeable = 0 - @staticmethod def create(path, size): cdef: @@ -425,11 +444,11 @@ cdef class MemoryMappedFile(NativeFile): cdef MemoryMappedFile result = MemoryMappedFile() result.path = path - result.is_readable = 1 - result.is_writeable = 1 + result.is_readable = True + result.is_writable = True result.wr_file = handle result.rd_file = handle - result.is_open = True + result.closed = False return result @@ -443,14 +462,14 @@ cdef class MemoryMappedFile(NativeFile): if mode in ('r', 'rb'): c_mode = FileMode_READ - self.is_readable = 1 + self.is_readable = True elif mode in ('w', 'wb'): c_mode = FileMode_WRITE - self.is_writeable = 1 + self.is_writable = True elif mode in ('r+', 'r+b', 'rb+'): c_mode = FileMode_READWRITE - self.is_readable = 1 - self.is_writeable = 1 + self.is_readable = True + self.is_writable = True else: raise ValueError('Invalid file mode: {0}'.format(mode)) @@ -459,7 +478,7 @@ cdef class MemoryMappedFile(NativeFile): self.wr_file = handle self.rd_file = handle - self.is_open = True + self.closed = False def memory_map(path, mode='r'): @@ -483,7 +502,7 @@ def memory_map(path, mode='r'): def create_memory_map(path, size): """ Create memory map at indicated path of the given size, return open - writeable file object + writable file object Parameters ---------- @@ -512,16 +531,14 @@ cdef class OSFile(NativeFile): shared_ptr[Readable] handle c_string c_path = encode_file_path(path) - self.is_readable = self.is_writeable = 0 - if mode in ('r', 'rb'): self._open_readable(c_path, maybe_unbox_memory_pool(memory_pool)) elif mode in ('w', 'wb'): - self._open_writeable(c_path) + self._open_writable(c_path) else: raise ValueError('Invalid file mode: {0}'.format(mode)) - self.is_open = True + self.closed = False cdef _open_readable(self, c_string path, CMemoryPool* pool): cdef shared_ptr[ReadableFile] handle @@ -529,15 +546,15 @@ cdef class OSFile(NativeFile): with nogil: check_status(ReadableFile.Open(path, pool, &handle)) - self.is_readable = 1 + self.is_readable = True self.rd_file = handle - cdef _open_writeable(self, c_string path): + cdef _open_writable(self, c_string path): cdef shared_ptr[FileOutputStream] handle with nogil: check_status(FileOutputStream.Open(path, &handle)) - self.is_writeable = 1 + self.is_writable = True self.wr_file = handle @@ -545,9 +562,8 @@ cdef class FixedSizeBufferWriter(NativeFile): def __cinit__(self, Buffer buffer): self.wr_file.reset(new CFixedSizeBufferWriter(buffer.buffer)) - self.is_readable = 0 - self.is_writeable = 1 - self.is_open = True + 
self.is_writable = True + self.closed = False def set_memcopy_threads(self, int num_threads): cdef CFixedSizeBufferWriter* writer = \ @@ -737,14 +753,13 @@ cdef class BufferOutputStream(NativeFile): self.buffer = _allocate_buffer(maybe_unbox_memory_pool(memory_pool)) self.wr_file.reset(new CBufferOutputStream( self.buffer)) - self.is_readable = 0 - self.is_writeable = 1 - self.is_open = True + self.is_writable = True + self.closed = False def get_result(self): with nogil: check_status(self.wr_file.get().Close()) - self.is_open = False + self.closed = True return pyarrow_wrap_buffer( self.buffer) @@ -752,9 +767,8 @@ cdef class MockOutputStream(NativeFile): def __cinit__(self): self.wr_file.reset(new CMockOutputStream()) - self.is_readable = 0 - self.is_writeable = 1 - self.is_open = True + self.is_writable = True + self.closed = False def size(self): return (self.wr_file.get()).GetExtentBytesWritten() @@ -779,9 +793,8 @@ cdef class BufferReader(NativeFile): self.buffer = frombuffer(obj) self.rd_file.reset(new CBufferReader(self.buffer.buffer)) - self.is_readable = 1 - self.is_writeable = 0 - self.is_open = True + self.is_readable = True + self.closed = False def frombuffer(object obj): @@ -833,8 +846,8 @@ cdef get_writer(object source, shared_ptr[OutputStream]* writer): if isinstance(source, NativeFile): nf = source - if not nf.is_writeable: - raise IOError('Native file is not writeable') + if not nf.is_writable: + raise IOError('Native file is not writable') nf.write_handle(writer) else: diff --git a/python/pyarrow/ipc.pxi b/python/pyarrow/ipc.pxi index 7534b0d0e87ec..a30a228ae878f 100644 --- a/python/pyarrow/ipc.pxi +++ b/python/pyarrow/ipc.pxi @@ -429,7 +429,7 @@ def write_tensor(Tensor tensor, NativeFile dest): int32_t metadata_length int64_t body_length - dest._assert_writeable() + dest._assert_writable() with nogil: check_status( diff --git a/python/pyarrow/ipc.py b/python/pyarrow/ipc.py index f264f089c4071..4081fc50e6df6 100644 --- a/python/pyarrow/ipc.py +++ b/python/pyarrow/ipc.py @@ -65,7 +65,7 @@ class RecordBatchStreamWriter(lib._RecordBatchWriter): Parameters ---------- sink : str, pyarrow.NativeFile, or file-like Python object - Either a file path, or a writeable file object + Either a file path, or a writable file object schema : pyarrow.Schema The Arrow schema for data to be written to the file """ @@ -96,7 +96,7 @@ class RecordBatchFileWriter(lib._RecordBatchFileWriter): Parameters ---------- sink : str, pyarrow.NativeFile, or file-like Python object - Either a file path, or a writeable file object + Either a file path, or a writable file object schema : pyarrow.Schema The Arrow schema for data to be written to the file """ diff --git a/python/pyarrow/lib.pxd b/python/pyarrow/lib.pxd index 90f749d6db633..161562c040c30 100644 --- a/python/pyarrow/lib.pxd +++ b/python/pyarrow/lib.pxd @@ -333,8 +333,8 @@ cdef class NativeFile: shared_ptr[RandomAccessFile] rd_file shared_ptr[OutputStream] wr_file bint is_readable - bint is_writeable - bint is_open + bint is_writable + readonly bint closed bint own_file # By implementing these "virtual" functions (all functions in Cython diff --git a/python/pyarrow/pandas_compat.py b/python/pyarrow/pandas_compat.py index 3af1f42990b0b..240cccdaffe56 100644 --- a/python/pyarrow/pandas_compat.py +++ b/python/pyarrow/pandas_compat.py @@ -452,13 +452,12 @@ def _reconstruct_block(item): categories=item['dictionary'], ordered=item['ordered']) block = _int.make_block(cat, placement=placement, - klass=_int.CategoricalBlock, - fastpath=True) + 
klass=_int.CategoricalBlock) elif 'timezone' in item: dtype = _make_datetimetz(item['timezone']) block = _int.make_block(block_arr, placement=placement, klass=_int.DatetimeTZBlock, - dtype=dtype, fastpath=True) + dtype=dtype) else: block = _int.make_block(block_arr, placement=placement) diff --git a/python/pyarrow/parquet.py b/python/pyarrow/parquet.py index 151e0df8a22d0..3a0924a27ceb2 100644 --- a/python/pyarrow/parquet.py +++ b/python/pyarrow/parquet.py @@ -215,7 +215,9 @@ def _sanitize_schema(schema, flavor): sanitized_fields.append(sanitized_field) else: sanitized_fields.append(field) - return pa.schema(sanitized_fields), schema_changed + + new_schema = pa.schema(sanitized_fields, metadata=schema.metadata) + return new_schema, schema_changed else: return schema, False diff --git a/python/pyarrow/plasma.pyx b/python/pyarrow/plasma.pyx index 29e233b6e4e67..32f6d189da08c 100644 --- a/python/pyarrow/plasma.pyx +++ b/python/pyarrow/plasma.pyx @@ -248,8 +248,8 @@ cdef class PlasmaClient: check_status(self.client.get().Get(ids.data(), ids.size(), timeout_ms, result[0].data())) - cdef _make_plasma_buffer(self, ObjectID object_id, shared_ptr[CBuffer] buffer, - int64_t size): + cdef _make_plasma_buffer(self, ObjectID object_id, + shared_ptr[CBuffer] buffer, int64_t size): result = PlasmaBuffer(object_id, self) result.init(buffer) return result @@ -302,7 +302,9 @@ cdef class PlasmaClient: check_status(self.client.get().Create(object_id.data, data_size, (metadata.data()), metadata.size(), &data)) - return self._make_mutable_plasma_buffer(object_id, data.get().mutable_data(), data_size) + return self._make_mutable_plasma_buffer(object_id, + data.get().mutable_data(), + data_size) def get_buffers(self, object_ids, timeout_ms=-1): """ diff --git a/python/pyarrow/serialization.pxi b/python/pyarrow/serialization.pxi index d95d582fe537e..e7a39905f1f65 100644 --- a/python/pyarrow/serialization.pxi +++ b/python/pyarrow/serialization.pxi @@ -50,6 +50,8 @@ cdef class SerializationContext: object types_to_pickle object custom_serializers object custom_deserializers + object pickle_serializer + object pickle_deserializer def __init__(self): # Types with special serialization handlers @@ -58,6 +60,23 @@ cdef class SerializationContext: self.types_to_pickle = set() self.custom_serializers = dict() self.custom_deserializers = dict() + self.pickle_serializer = pickle.dumps + self.pickle_deserializer = pickle.loads + + def set_pickle(self, serializer, deserializer): + """ + Set the serializer and deserializer to use for objects that are to be + pickled. + + Parameters + ---------- + serializer : callable + The serializer to use (e.g., pickle.dumps or cloudpickle.dumps). + deserializer : callable + The deserializer to use (e.g., pickle.dumps or cloudpickle.dumps). 
+ """ + self.pickle_serializer = serializer + self.pickle_deserializer = deserializer def clone(self): """ @@ -72,6 +91,8 @@ cdef class SerializationContext: result.whitelisted_types = self.whitelisted_types.copy() result.custom_serializers = self.custom_serializers.copy() result.custom_deserializers = self.custom_deserializers.copy() + result.pickle_serializer = self.pickle_serializer + result.pickle_deserializer = self.pickle_deserializer return result @@ -119,7 +140,8 @@ cdef class SerializationContext: # use the closest match to type(obj) type_id = self.type_to_type_id[type_] if type_id in self.types_to_pickle: - serialized_obj = {"data": pickle.dumps(obj), "pickle": True} + serialized_obj = {"data": self.pickle_serializer(obj), + "pickle": True} elif type_id in self.custom_serializers: serialized_obj = {"data": self.custom_serializers[type_id](obj)} else: @@ -139,7 +161,7 @@ cdef class SerializationContext: if "pickle" in serialized_obj: # The object was pickled, so unpickle it. - obj = pickle.loads(serialized_obj["data"]) + obj = self.pickle_deserializer(serialized_obj["data"]) else: assert type_id not in self.types_to_pickle if type_id not in self.whitelisted_types: diff --git a/python/pyarrow/serialization.py b/python/pyarrow/serialization.py index 61f2e83f3193d..c8b72b74896c9 100644 --- a/python/pyarrow/serialization.py +++ b/python/pyarrow/serialization.py @@ -22,7 +22,8 @@ import numpy as np from pyarrow.compat import builtin_pickle -from pyarrow.lib import _default_serialization_context, frombuffer +from pyarrow.lib import (SerializationContext, _default_serialization_context, + frombuffer) try: import cloudpickle @@ -102,6 +103,31 @@ def _deserialize_pandas_series(data): custom_deserializer=_deserialize_pandas_dataframe) +def register_torch_serialization_handlers(serialization_context): + # ---------------------------------------------------------------------- + # Set up serialization for pytorch tensors + + try: + import torch + + def _serialize_torch_tensor(obj): + return obj.numpy() + + def _deserialize_torch_tensor(data): + return torch.from_numpy(data) + + for t in [torch.FloatTensor, torch.DoubleTensor, torch.HalfTensor, + torch.ByteTensor, torch.CharTensor, torch.ShortTensor, + torch.IntTensor, torch.LongTensor]: + serialization_context.register_type( + t, "torch." + t.__name__, + custom_serializer=_serialize_torch_tensor, + custom_deserializer=_deserialize_torch_tensor) + except ImportError: + # no torch + pass + + def register_default_serialization_handlers(serialization_context): # ---------------------------------------------------------------------- @@ -154,37 +180,21 @@ def _deserialize_default_dict(data): custom_serializer=_serialize_numpy_array_list, custom_deserializer=_deserialize_numpy_array_list) - # ---------------------------------------------------------------------- - # Set up serialization for pytorch tensors - - try: - import torch - - def _serialize_torch_tensor(obj): - return obj.numpy() + _register_custom_pandas_handlers(serialization_context) - def _deserialize_torch_tensor(data): - return torch.from_numpy(data) - for t in [torch.FloatTensor, torch.DoubleTensor, torch.HalfTensor, - torch.ByteTensor, torch.CharTensor, torch.ShortTensor, - torch.IntTensor, torch.LongTensor]: - serialization_context.register_type( - t, "torch." 
+ t.__name__, - custom_serializer=_serialize_torch_tensor, - custom_deserializer=_deserialize_torch_tensor) - except ImportError: - # no torch - pass - - _register_custom_pandas_handlers(serialization_context) +def default_serialization_context(): + context = SerializationContext() + register_default_serialization_handlers(context) + return context register_default_serialization_handlers(_default_serialization_context) -pandas_serialization_context = _default_serialization_context.clone() -pandas_serialization_context.register_type( - np.ndarray, 'np.array', - custom_serializer=_serialize_numpy_array_pickle, - custom_deserializer=_deserialize_numpy_array_pickle) +def pandas_serialization_context(): + context = default_serialization_context() + context.register_type(np.ndarray, 'np.array', + custom_serializer=_serialize_numpy_array_pickle, + custom_deserializer=_deserialize_numpy_array_pickle) + return context diff --git a/python/pyarrow/tests/test_array.py b/python/pyarrow/tests/test_array.py index fa38c9257854e..1d5d30071902a 100644 --- a/python/pyarrow/tests/test_array.py +++ b/python/pyarrow/tests/test_array.py @@ -485,6 +485,12 @@ def test_logical_type(type, expected): assert get_logical_type(type) == expected +def test_array_uint64_from_py_over_range(): + arr = pa.array([2 ** 63], type=pa.uint64()) + expected = pa.array(np.array([2 ** 63], dtype='u8')) + assert arr.equals(expected) + + def test_array_conversions_no_sentinel_values(): arr = np.array([1, 2, 3, 4], dtype='int8') refcount = sys.getrefcount(arr) @@ -507,6 +513,23 @@ def test_array_from_numpy_datetimeD(): assert result.equals(expected) +def test_array_from_py_float32(): + data = [[1.2, 3.4], [9.0, 42.0]] + + t = pa.float32() + + arr1 = pa.array(data[0], type=t) + arr2 = pa.array(data, type=pa.list_(t)) + + expected1 = np.array(data[0], dtype=np.float32) + expected2 = pd.Series([np.array(data[0], dtype=np.float32), + np.array(data[1], dtype=np.float32)]) + + assert arr1.type == t + assert arr1.equals(pa.array(expected1)) + assert arr2.equals(pa.array(expected2)) + + def test_array_from_numpy_ascii(): arr = np.array(['abcde', 'abc', ''], dtype='|S5') diff --git a/python/pyarrow/tests/test_convert_builtin.py b/python/pyarrow/tests/test_convert_builtin.py index d7760da2f9b47..bbdf6e71e0f1d 100644 --- a/python/pyarrow/tests/test_convert_builtin.py +++ b/python/pyarrow/tests/test_convert_builtin.py @@ -23,6 +23,9 @@ import datetime import decimal +import itertools +import numpy as np +import six class StrangeIterable: @@ -33,356 +36,496 @@ def __iter__(self): return self.lst.__iter__() -class TestConvertIterable(unittest.TestCase): - - def test_iterable_types(self): - arr1 = pa.array(StrangeIterable([0, 1, 2, 3])) - arr2 = pa.array((0, 1, 2, 3)) - - assert arr1.equals(arr2) - - def test_empty_iterable(self): - arr = pa.array(StrangeIterable([])) - assert len(arr) == 0 - assert arr.null_count == 0 - assert arr.type == pa.null() - assert arr.to_pylist() == [] - - -class TestLimitedConvertIterator(unittest.TestCase): - def test_iterator_types(self): - arr1 = pa.array(iter(range(3)), type=pa.int64(), size=3) - arr2 = pa.array((0, 1, 2)) - assert arr1.equals(arr2) - - def test_iterator_size_overflow(self): - arr1 = pa.array(iter(range(3)), type=pa.int64(), size=2) - arr2 = pa.array((0, 1)) - assert arr1.equals(arr2) - - def test_iterator_size_underflow(self): - arr1 = pa.array(iter(range(3)), type=pa.int64(), size=10) - arr2 = pa.array((0, 1, 2)) - assert arr1.equals(arr2) - - -class TestConvertSequence(unittest.TestCase): - - def 
test_sequence_types(self): - arr1 = pa.array([1, 2, 3]) - arr2 = pa.array((1, 2, 3)) - - assert arr1.equals(arr2) - - def test_boolean(self): - expected = [True, None, False, None] - arr = pa.array(expected) - assert len(arr) == 4 - assert arr.null_count == 2 - assert arr.type == pa.bool_() - assert arr.to_pylist() == expected - - def test_empty_list(self): - arr = pa.array([]) - assert len(arr) == 0 - assert arr.null_count == 0 - assert arr.type == pa.null() - assert arr.to_pylist() == [] - - def test_all_none(self): - arr = pa.array([None, None]) - assert len(arr) == 2 - assert arr.null_count == 2 - assert arr.type == pa.null() - assert arr.to_pylist() == [None, None] - - def test_integer(self): - expected = [1, None, 3, None] - arr = pa.array(expected) - assert len(arr) == 4 - assert arr.null_count == 2 - assert arr.type == pa.int64() - assert arr.to_pylist() == expected - - def test_garbage_collection(self): - import gc - - # Force the cyclic garbage collector to run - gc.collect() - - bytes_before = pa.total_allocated_bytes() - pa.array([1, None, 3, None]) - gc.collect() - assert pa.total_allocated_bytes() == bytes_before - - def test_double(self): - data = [1.5, 1, None, 2.5, None, None] - arr = pa.array(data) - assert len(arr) == 6 - assert arr.null_count == 3 - assert arr.type == pa.float64() - assert arr.to_pylist() == data - - def test_unicode(self): - data = [u'foo', u'bar', None, u'mañana'] - arr = pa.array(data) - assert len(arr) == 4 - assert arr.null_count == 1 - assert arr.type == pa.string() - assert arr.to_pylist() == data - - def test_bytes(self): - u1 = b'ma\xc3\xb1ana' - data = [b'foo', - u1.decode('utf-8'), # unicode gets encoded, - None] - arr = pa.array(data) - assert len(arr) == 3 - assert arr.null_count == 1 - assert arr.type == pa.binary() - assert arr.to_pylist() == [b'foo', u1, None] - - def test_utf8_to_unicode(self): - # ARROW-1225 - data = [b'foo', None, b'bar'] - arr = pa.array(data, type=pa.string()) - assert arr[0].as_py() == u'foo' - - # test a non-utf8 unicode string - val = (u'mañana').encode('utf-16-le') - with pytest.raises(pa.ArrowException): - pa.array([val], type=pa.string()) - - def test_fixed_size_bytes(self): - data = [b'foof', None, b'barb', b'2346'] - arr = pa.array(data, type=pa.binary(4)) - assert len(arr) == 4 - assert arr.null_count == 1 - assert arr.type == pa.binary(4) - assert arr.to_pylist() == data - - def test_fixed_size_bytes_does_not_accept_varying_lengths(self): - data = [b'foo', None, b'barb', b'2346'] - with self.assertRaises(pa.ArrowInvalid): - pa.array(data, type=pa.binary(4)) - - def test_date(self): - data = [datetime.date(2000, 1, 1), None, datetime.date(1970, 1, 1), - datetime.date(2040, 2, 26)] - arr = pa.array(data) - assert len(arr) == 4 - assert arr.type == pa.date64() - assert arr.null_count == 1 - assert arr[0].as_py() == datetime.date(2000, 1, 1) - assert arr[1].as_py() is None - assert arr[2].as_py() == datetime.date(1970, 1, 1) - assert arr[3].as_py() == datetime.date(2040, 2, 26) - - def test_date32(self): - data = [datetime.date(2000, 1, 1), None] - arr = pa.array(data, type=pa.date32()) - - data2 = [10957, None] - arr2 = pa.array(data2, type=pa.date32()) - - for x in [arr, arr2]: - assert len(x) == 2 - assert x.type == pa.date32() - assert x.null_count == 1 - assert x[0].as_py() == datetime.date(2000, 1, 1) - assert x[1] is pa.NA - - # Overflow - data3 = [2**32, None] - with pytest.raises(pa.ArrowException): - pa.array(data3, type=pa.date32()) - - def test_timestamp(self): - data = [ - datetime.datetime(2007, 
7, 13, 1, 23, 34, 123456), +def test_iterable_types(): + arr1 = pa.array(StrangeIterable([0, 1, 2, 3])) + arr2 = pa.array((0, 1, 2, 3)) + + assert arr1.equals(arr2) + + +def test_empty_iterable(): + arr = pa.array(StrangeIterable([])) + assert len(arr) == 0 + assert arr.null_count == 0 + assert arr.type == pa.null() + assert arr.to_pylist() == [] + + +def test_limited_iterator_types(): + arr1 = pa.array(iter(range(3)), type=pa.int64(), size=3) + arr2 = pa.array((0, 1, 2)) + assert arr1.equals(arr2) + + +def test_limited_iterator_size_overflow(): + arr1 = pa.array(iter(range(3)), type=pa.int64(), size=2) + arr2 = pa.array((0, 1)) + assert arr1.equals(arr2) + + +def test_limited_iterator_size_underflow(): + arr1 = pa.array(iter(range(3)), type=pa.int64(), size=10) + arr2 = pa.array((0, 1, 2)) + assert arr1.equals(arr2) + + +def test_iterator_without_size(): + expected = pa.array((0, 1, 2)) + arr1 = pa.array(iter(range(3))) + assert arr1.equals(expected) + # Same with explicit type + arr1 = pa.array(iter(range(3)), type=pa.int64()) + assert arr1.equals(expected) + + +def test_infinite_iterator(): + expected = pa.array((0, 1, 2)) + arr1 = pa.array(itertools.count(0), size=3) + assert arr1.equals(expected) + # Same with explicit type + arr1 = pa.array(itertools.count(0), type=pa.int64(), size=3) + assert arr1.equals(expected) + + +def _as_list(xs): + return xs + + +def _as_tuple(xs): + return tuple(xs) + + +def _as_dict_values(xs): + dct = {k: v for k, v in enumerate(xs)} + return six.viewvalues(dct) + + +@pytest.mark.parametrize("seq", [_as_list, _as_tuple, _as_dict_values]) +def test_sequence_types(seq): + arr1 = pa.array(seq([1, 2, 3])) + arr2 = pa.array([1, 2, 3]) + + assert arr1.equals(arr2) + + +@pytest.mark.parametrize("seq", [_as_list, _as_tuple, _as_dict_values]) +def test_sequence_boolean(seq): + expected = [True, None, False, None] + arr = pa.array(seq(expected)) + assert len(arr) == 4 + assert arr.null_count == 2 + assert arr.type == pa.bool_() + assert arr.to_pylist() == expected + + +@pytest.mark.parametrize("seq", [_as_list, _as_tuple, _as_dict_values]) +def test_sequence_numpy_boolean(seq): + expected = [np.bool(True), None, np.bool(False), None] + arr = pa.array(seq(expected)) + assert len(arr) == 4 + assert arr.null_count == 2 + assert arr.type == pa.bool_() + assert arr.to_pylist() == expected + + +@pytest.mark.parametrize("seq", [_as_list, _as_tuple, _as_dict_values]) +def test_empty_list(seq): + arr = pa.array(seq([])) + assert len(arr) == 0 + assert arr.null_count == 0 + assert arr.type == pa.null() + assert arr.to_pylist() == [] + + +def test_sequence_all_none(): + arr = pa.array([None, None]) + assert len(arr) == 2 + assert arr.null_count == 2 + assert arr.type == pa.null() + assert arr.to_pylist() == [None, None] + + +@pytest.mark.parametrize("seq", [_as_list, _as_tuple, _as_dict_values]) +def test_sequence_integer(seq): + expected = [1, None, 3, None] + arr = pa.array(seq(expected)) + assert len(arr) == 4 + assert arr.null_count == 2 + assert arr.type == pa.int64() + assert arr.to_pylist() == expected + + +@pytest.mark.parametrize("seq", [_as_list, _as_tuple, _as_dict_values]) +@pytest.mark.parametrize("np_scalar", [np.int16, np.int32, np.int64, np.uint16, + np.uint32, np.uint64]) +def test_sequence_numpy_integer(seq, np_scalar): + expected = [np_scalar(1), None, np_scalar(3), None] + arr = pa.array(seq(expected)) + assert len(arr) == 4 + assert arr.null_count == 2 + assert arr.type == pa.int64() + assert arr.to_pylist() == expected + + +def test_garbage_collection(): 
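+    # After forcing a GC pass, allocating a temporary array and collecting
+    # again should leave pyarrow's total_allocated_bytes() unchanged.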
+ import gc + + # Force the cyclic garbage collector to run + gc.collect() + + bytes_before = pa.total_allocated_bytes() + pa.array([1, None, 3, None]) + gc.collect() + assert pa.total_allocated_bytes() == bytes_before + + +def test_sequence_double(): + data = [1.5, 1, None, 2.5, None, None] + arr = pa.array(data) + assert len(arr) == 6 + assert arr.null_count == 3 + assert arr.type == pa.float64() + assert arr.to_pylist() == data + + +@pytest.mark.parametrize("seq", [_as_list, _as_tuple, _as_dict_values]) +@pytest.mark.parametrize("np_scalar", [np.float16, np.float32, np.float64]) +def test_sequence_numpy_double(seq, np_scalar): + data = [np_scalar(1.5), np_scalar(1), None, np_scalar(2.5), None, None] + arr = pa.array(seq(data)) + assert len(arr) == 6 + assert arr.null_count == 3 + assert arr.type == pa.float64() + assert arr.to_pylist() == data + + +def test_sequence_unicode(): + data = [u'foo', u'bar', None, u'mañana'] + arr = pa.array(data) + assert len(arr) == 4 + assert arr.null_count == 1 + assert arr.type == pa.string() + assert arr.to_pylist() == data + + +def test_sequence_bytes(): + u1 = b'ma\xc3\xb1ana' + data = [b'foo', + u1.decode('utf-8'), # unicode gets encoded, + None] + arr = pa.array(data) + assert len(arr) == 3 + assert arr.null_count == 1 + assert arr.type == pa.binary() + assert arr.to_pylist() == [b'foo', u1, None] + + +def test_sequence_utf8_to_unicode(): + # ARROW-1225 + data = [b'foo', None, b'bar'] + arr = pa.array(data, type=pa.string()) + assert arr[0].as_py() == u'foo' + + # test a non-utf8 unicode string + val = (u'mañana').encode('utf-16-le') + with pytest.raises(pa.ArrowException): + pa.array([val], type=pa.string()) + + +def test_sequence_fixed_size_bytes(): + data = [b'foof', None, b'barb', b'2346'] + arr = pa.array(data, type=pa.binary(4)) + assert len(arr) == 4 + assert arr.null_count == 1 + assert arr.type == pa.binary(4) + assert arr.to_pylist() == data + + +def test_fixed_size_bytes_does_not_accept_varying_lengths(): + data = [b'foo', None, b'barb', b'2346'] + with pytest.raises(pa.ArrowInvalid): + pa.array(data, type=pa.binary(4)) + + +def test_sequence_date(): + data = [datetime.date(2000, 1, 1), None, datetime.date(1970, 1, 1), + datetime.date(2040, 2, 26)] + arr = pa.array(data) + assert len(arr) == 4 + assert arr.type == pa.date64() + assert arr.null_count == 1 + assert arr[0].as_py() == datetime.date(2000, 1, 1) + assert arr[1].as_py() is None + assert arr[2].as_py() == datetime.date(1970, 1, 1) + assert arr[3].as_py() == datetime.date(2040, 2, 26) + + +def test_sequence_date32(): + data = [datetime.date(2000, 1, 1), None] + arr = pa.array(data, type=pa.date32()) + + data2 = [10957, None] + arr2 = pa.array(data2, type=pa.date32()) + + for x in [arr, arr2]: + assert len(x) == 2 + assert x.type == pa.date32() + assert x.null_count == 1 + assert x[0].as_py() == datetime.date(2000, 1, 1) + assert x[1] is pa.NA + + # Overflow + data3 = [2**32, None] + with pytest.raises(pa.ArrowException): + pa.array(data3, type=pa.date32()) + + +def test_sequence_timestamp(): + data = [ + datetime.datetime(2007, 7, 13, 1, 23, 34, 123456), + None, + datetime.datetime(2006, 1, 13, 12, 34, 56, 432539), + datetime.datetime(2010, 8, 13, 5, 46, 57, 437699) + ] + arr = pa.array(data) + assert len(arr) == 4 + assert arr.type == pa.timestamp('us') + assert arr.null_count == 1 + assert arr[0].as_py() == datetime.datetime(2007, 7, 13, 1, + 23, 34, 123456) + assert arr[1].as_py() is None + assert arr[2].as_py() == datetime.datetime(2006, 1, 13, 12, + 34, 56, 432539) + assert 
arr[3].as_py() == datetime.datetime(2010, 8, 13, 5, + 46, 57, 437699) + + +def test_sequence_numpy_timestamp(): + data = [ + np.datetime64(datetime.datetime(2007, 7, 13, 1, 23, 34, 123456)), + None, + np.datetime64(datetime.datetime(2006, 1, 13, 12, 34, 56, 432539)), + np.datetime64(datetime.datetime(2010, 8, 13, 5, 46, 57, 437699)) + ] + arr = pa.array(data) + assert len(arr) == 4 + assert arr.type == pa.timestamp('us') + assert arr.null_count == 1 + assert arr[0].as_py() == datetime.datetime(2007, 7, 13, 1, + 23, 34, 123456) + assert arr[1].as_py() is None + assert arr[2].as_py() == datetime.datetime(2006, 1, 13, 12, + 34, 56, 432539) + assert arr[3].as_py() == datetime.datetime(2010, 8, 13, 5, + 46, 57, 437699) + + +def test_sequence_timestamp_with_unit(): + data = [ + datetime.datetime(2007, 7, 13, 1, 23, 34, 123456), + ] + + s = pa.timestamp('s') + ms = pa.timestamp('ms') + us = pa.timestamp('us') + ns = pa.timestamp('ns') + + arr_s = pa.array(data, type=s) + assert len(arr_s) == 1 + assert arr_s.type == s + assert arr_s[0].as_py() == datetime.datetime(2007, 7, 13, 1, + 23, 34, 0) + + arr_ms = pa.array(data, type=ms) + assert len(arr_ms) == 1 + assert arr_ms.type == ms + assert arr_ms[0].as_py() == datetime.datetime(2007, 7, 13, 1, + 23, 34, 123000) + + arr_us = pa.array(data, type=us) + assert len(arr_us) == 1 + assert arr_us.type == us + assert arr_us[0].as_py() == datetime.datetime(2007, 7, 13, 1, + 23, 34, 123456) + + arr_ns = pa.array(data, type=ns) + assert len(arr_ns) == 1 + assert arr_ns.type == ns + assert arr_ns[0].as_py() == datetime.datetime(2007, 7, 13, 1, + 23, 34, 123456) + + +def test_sequence_timestamp_from_int_with_unit(): + data = [1] + + s = pa.timestamp('s') + ms = pa.timestamp('ms') + us = pa.timestamp('us') + ns = pa.timestamp('ns') + + arr_s = pa.array(data, type=s) + assert len(arr_s) == 1 + assert arr_s.type == s + assert str(arr_s[0]) == "Timestamp('1970-01-01 00:00:01')" + + arr_ms = pa.array(data, type=ms) + assert len(arr_ms) == 1 + assert arr_ms.type == ms + assert str(arr_ms[0]) == "Timestamp('1970-01-01 00:00:00.001000')" + + arr_us = pa.array(data, type=us) + assert len(arr_us) == 1 + assert arr_us.type == us + assert str(arr_us[0]) == "Timestamp('1970-01-01 00:00:00.000001')" + + arr_ns = pa.array(data, type=ns) + assert len(arr_ns) == 1 + assert arr_ns.type == ns + assert str(arr_ns[0]) == "Timestamp('1970-01-01 00:00:00.000000001')" + + with pytest.raises(pa.ArrowException): + class CustomClass(): + pass + pa.array([1, CustomClass()], type=ns) + pa.array([1, CustomClass()], type=pa.date32()) + pa.array([1, CustomClass()], type=pa.date64()) + + +def test_sequence_mixed_nesting_levels(): + pa.array([1, 2, None]) + pa.array([[1], [2], None]) + pa.array([[1], [2], [None]]) + + with pytest.raises(pa.ArrowInvalid): + pa.array([1, 2, [1]]) + + with pytest.raises(pa.ArrowInvalid): + pa.array([1, 2, []]) + + with pytest.raises(pa.ArrowInvalid): + pa.array([[1], [2], [None, [1]]]) + + +def test_sequence_list_of_int(): + data = [[1, 2, 3], [], None, [1, 2]] + arr = pa.array(data) + assert len(arr) == 4 + assert arr.null_count == 1 + assert arr.type == pa.list_(pa.int64()) + assert arr.to_pylist() == data + + +def test_sequence_mixed_types_fails(): + data = ['a', 1, 2.0] + with pytest.raises(pa.ArrowException): + pa.array(data) + + +def test_sequence_mixed_types_with_specified_type_fails(): + data = ['-10', '-5', {'a': 1}, '0', '5', '10'] + + type = pa.string() + with pytest.raises(pa.ArrowInvalid): + pa.array(data, type=type) + + +def test_sequence_decimal(): 
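+    # Python decimal.Decimal values should round-trip unchanged through a
+    # decimal128(precision=7, scale=3) array and back via to_pylist().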
+ data = [decimal.Decimal('1234.183'), decimal.Decimal('8094.234')] + type = pa.decimal128(precision=7, scale=3) + arr = pa.array(data, type=type) + assert arr.to_pylist() == data + + +def test_sequence_decimal_different_precisions(): + data = [ + decimal.Decimal('1234234983.183'), decimal.Decimal('80943244.234') + ] + type = pa.decimal128(precision=13, scale=3) + arr = pa.array(data, type=type) + assert arr.to_pylist() == data + + +def test_sequence_decimal_no_scale(): + data = [decimal.Decimal('1234234983'), decimal.Decimal('8094324')] + type = pa.decimal128(precision=10) + arr = pa.array(data, type=type) + assert arr.to_pylist() == data + + +def test_sequence_decimal_negative(): + data = [decimal.Decimal('-1234.234983'), decimal.Decimal('-8.094324')] + type = pa.decimal128(precision=10, scale=6) + arr = pa.array(data, type=type) + assert arr.to_pylist() == data + + +def test_sequence_decimal_no_whole_part(): + data = [decimal.Decimal('-.4234983'), decimal.Decimal('.0103943')] + type = pa.decimal128(precision=7, scale=7) + arr = pa.array(data, type=type) + assert arr.to_pylist() == data + + +def test_sequence_decimal_large_integer(): + data = [decimal.Decimal('-394029506937548693.42983'), + decimal.Decimal('32358695912932.01033')] + type = pa.decimal128(precision=23, scale=5) + arr = pa.array(data, type=type) + assert arr.to_pylist() == data + + +def test_range_types(): + arr1 = pa.array(range(3)) + arr2 = pa.array((0, 1, 2)) + assert arr1.equals(arr2) + + +def test_empty_range(): + arr = pa.array(range(0)) + assert len(arr) == 0 + assert arr.null_count == 0 + assert arr.type == pa.null() + assert arr.to_pylist() == [] + + +def test_structarray(): + ints = pa.array([None, 2, 3], type=pa.int64()) + strs = pa.array([u'a', None, u'c'], type=pa.string()) + bools = pa.array([True, False, None], type=pa.bool_()) + arr = pa.StructArray.from_arrays( + ['ints', 'strs', 'bools'], + [ints, strs, bools]) + + expected = [ + {'ints': None, 'strs': u'a', 'bools': True}, + {'ints': 2, 'strs': None, 'bools': False}, + {'ints': 3, 'strs': u'c', 'bools': None}, + ] + + pylist = arr.to_pylist() + assert pylist == expected, (pylist, expected) + + +def test_struct_from_dicts(): + ty = pa.struct([pa.field('a', pa.int32()), + pa.field('b', pa.string()), + pa.field('c', pa.bool_())]) + arr = pa.array([], type=ty) + assert arr.to_pylist() == [] + + data = [{'a': 5, 'b': 'foo', 'c': True}, + {'a': 6, 'b': 'bar', 'c': False}] + arr = pa.array(data, type=ty) + assert arr.to_pylist() == data + + # With omitted values + data = [{'a': 5, 'c': True}, None, - datetime.datetime(2006, 1, 13, 12, 34, 56, 432539), - datetime.datetime(2010, 8, 13, 5, 46, 57, 437699) - ] - arr = pa.array(data) - assert len(arr) == 4 - assert arr.type == pa.timestamp('us') - assert arr.null_count == 1 - assert arr[0].as_py() == datetime.datetime(2007, 7, 13, 1, - 23, 34, 123456) - assert arr[1].as_py() is None - assert arr[2].as_py() == datetime.datetime(2006, 1, 13, 12, - 34, 56, 432539) - assert arr[3].as_py() == datetime.datetime(2010, 8, 13, 5, - 46, 57, 437699) - - def test_timestamp_with_unit(self): - data = [ - datetime.datetime(2007, 7, 13, 1, 23, 34, 123456), - ] - - s = pa.timestamp('s') - ms = pa.timestamp('ms') - us = pa.timestamp('us') - ns = pa.timestamp('ns') - - arr_s = pa.array(data, type=s) - assert len(arr_s) == 1 - assert arr_s.type == s - assert arr_s[0].as_py() == datetime.datetime(2007, 7, 13, 1, - 23, 34, 0) - - arr_ms = pa.array(data, type=ms) - assert len(arr_ms) == 1 - assert arr_ms.type == ms - assert 
arr_ms[0].as_py() == datetime.datetime(2007, 7, 13, 1, - 23, 34, 123000) - - arr_us = pa.array(data, type=us) - assert len(arr_us) == 1 - assert arr_us.type == us - assert arr_us[0].as_py() == datetime.datetime(2007, 7, 13, 1, - 23, 34, 123456) - - arr_ns = pa.array(data, type=ns) - assert len(arr_ns) == 1 - assert arr_ns.type == ns - assert arr_ns[0].as_py() == datetime.datetime(2007, 7, 13, 1, - 23, 34, 123456) - - def test_timestamp_from_int_with_unit(self): - data = [1] - - s = pa.timestamp('s') - ms = pa.timestamp('ms') - us = pa.timestamp('us') - ns = pa.timestamp('ns') - - arr_s = pa.array(data, type=s) - assert len(arr_s) == 1 - assert arr_s.type == s - assert str(arr_s[0]) == "Timestamp('1970-01-01 00:00:01')" - - arr_ms = pa.array(data, type=ms) - assert len(arr_ms) == 1 - assert arr_ms.type == ms - assert str(arr_ms[0]) == "Timestamp('1970-01-01 00:00:00.001000')" - - arr_us = pa.array(data, type=us) - assert len(arr_us) == 1 - assert arr_us.type == us - assert str(arr_us[0]) == "Timestamp('1970-01-01 00:00:00.000001')" - - arr_ns = pa.array(data, type=ns) - assert len(arr_ns) == 1 - assert arr_ns.type == ns - assert str(arr_ns[0]) == "Timestamp('1970-01-01 00:00:00.000000001')" - - with pytest.raises(pa.ArrowException): - class CustomClass(): - pass - pa.array([1, CustomClass()], type=ns) - pa.array([1, CustomClass()], type=pa.date32()) - pa.array([1, CustomClass()], type=pa.date64()) - - def test_mixed_nesting_levels(self): - pa.array([1, 2, None]) - pa.array([[1], [2], None]) - pa.array([[1], [2], [None]]) - - with self.assertRaises(pa.ArrowInvalid): - pa.array([1, 2, [1]]) - - with self.assertRaises(pa.ArrowInvalid): - pa.array([1, 2, []]) - - with self.assertRaises(pa.ArrowInvalid): - pa.array([[1], [2], [None, [1]]]) - - def test_list_of_int(self): - data = [[1, 2, 3], [], None, [1, 2]] - arr = pa.array(data) - assert len(arr) == 4 - assert arr.null_count == 1 - assert arr.type == pa.list_(pa.int64()) - assert arr.to_pylist() == data - - def test_mixed_types_fails(self): - data = ['a', 1, 2.0] - with self.assertRaises(pa.ArrowException): - pa.array(data) - - def test_mixed_types_with_specified_type_fails(self): - data = ['-10', '-5', {'a': 1}, '0', '5', '10'] - - type = pa.string() - with self.assertRaises(pa.ArrowInvalid): - pa.array(data, type=type) - - def test_decimal(self): - data = [decimal.Decimal('1234.183'), decimal.Decimal('8094.234')] - type = pa.decimal128(precision=7, scale=3) - arr = pa.array(data, type=type) - assert arr.to_pylist() == data - - def test_decimal_different_precisions(self): - data = [ - decimal.Decimal('1234234983.183'), decimal.Decimal('80943244.234') - ] - type = pa.decimal128(precision=13, scale=3) - arr = pa.array(data, type=type) - assert arr.to_pylist() == data - - def test_decimal_no_scale(self): - data = [decimal.Decimal('1234234983'), decimal.Decimal('8094324')] - type = pa.decimal128(precision=10) - arr = pa.array(data, type=type) - assert arr.to_pylist() == data - - def test_decimal_negative(self): - data = [decimal.Decimal('-1234.234983'), decimal.Decimal('-8.094324')] - type = pa.decimal128(precision=10, scale=6) - arr = pa.array(data, type=type) - assert arr.to_pylist() == data - - def test_decimal_no_whole_part(self): - data = [decimal.Decimal('-.4234983'), decimal.Decimal('.0103943')] - type = pa.decimal128(precision=7, scale=7) - arr = pa.array(data, type=type) - assert arr.to_pylist() == data - - def test_decimal_large_integer(self): - data = [decimal.Decimal('-394029506937548693.42983'), - 
decimal.Decimal('32358695912932.01033')] - type = pa.decimal128(precision=23, scale=5) - arr = pa.array(data, type=type) - assert arr.to_pylist() == data - - def test_range_types(self): - arr1 = pa.array(range(3)) - arr2 = pa.array((0, 1, 2)) - assert arr1.equals(arr2) - - def test_empty_range(self): - arr = pa.array(range(0)) - assert len(arr) == 0 - assert arr.null_count == 0 - assert arr.type == pa.null() - assert arr.to_pylist() == [] - - def test_structarray(self): - ints = pa.array([None, 2, 3], type=pa.int64()) - strs = pa.array([u'a', None, u'c'], type=pa.string()) - bools = pa.array([True, False, None], type=pa.bool_()) - arr = pa.StructArray.from_arrays( - ['ints', 'strs', 'bools'], - [ints, strs, bools]) - - expected = [ - {'ints': None, 'strs': u'a', 'bools': True}, - {'ints': 2, 'strs': None, 'bools': False}, - {'ints': 3, 'strs': u'c', 'bools': None}, - ] - - pylist = arr.to_pylist() - assert pylist == expected, (pylist, expected) + {}, + {'a': None, 'b': 'bar'}] + arr = pa.array(data, type=ty) + expected = [{'a': 5, 'b': None, 'c': True}, + None, + {'a': None, 'b': None, 'c': None}, + {'a': None, 'b': 'bar', 'c': None}] + assert arr.to_pylist() == expected diff --git a/python/pyarrow/tests/test_convert_pandas.py b/python/pyarrow/tests/test_convert_pandas.py index df8d982209c46..f1f40a695edc1 100644 --- a/python/pyarrow/tests/test_convert_pandas.py +++ b/python/pyarrow/tests/test_convert_pandas.py @@ -1238,7 +1238,37 @@ def test_decimal_metadata(self): assert data_column['numpy_type'] == 'object' assert data_column['metadata'] == {'precision': 26, 'scale': 11} - def test_table_str_to_categorical(self): + def test_table_empty_str(self): + values = ['', '', '', '', ''] + df = pd.DataFrame({'strings': values}) + field = pa.field('strings', pa.string()) + schema = pa.schema([field]) + table = pa.Table.from_pandas(df, schema=schema) + + result1 = table.to_pandas(strings_to_categorical=False) + expected1 = pd.DataFrame({'strings': values}) + tm.assert_frame_equal(result1, expected1, check_dtype=True) + + result2 = table.to_pandas(strings_to_categorical=True) + expected2 = pd.DataFrame({'strings': pd.Categorical(values)}) + tm.assert_frame_equal(result2, expected2, check_dtype=True) + + def test_table_str_to_categorical_without_na(self): + values = ['a', 'a', 'b', 'b', 'c'] + df = pd.DataFrame({'strings': values}) + field = pa.field('strings', pa.string()) + schema = pa.schema([field]) + table = pa.Table.from_pandas(df, schema=schema) + + result = table.to_pandas(strings_to_categorical=True) + expected = pd.DataFrame({'strings': pd.Categorical(values)}) + tm.assert_frame_equal(result, expected, check_dtype=True) + + with pytest.raises(pa.ArrowInvalid): + table.to_pandas(strings_to_categorical=True, + zero_copy_only=True) + + def test_table_str_to_categorical_with_na(self): values = [None, 'a', 'b', np.nan] df = pd.DataFrame({'strings': values}) field = pa.field('strings', pa.string()) @@ -1249,6 +1279,10 @@ def test_table_str_to_categorical(self): expected = pd.DataFrame({'strings': pd.Categorical(values)}) tm.assert_frame_equal(result, expected, check_dtype=True) + with pytest.raises(pa.ArrowInvalid): + table.to_pandas(strings_to_categorical=True, + zero_copy_only=True) + def test_table_batch_empty_dataframe(self): df = pd.DataFrame({}) _check_pandas_roundtrip(df) @@ -1371,7 +1405,7 @@ def _fully_loaded_dataframe_example(): def _check_serialize_components_roundtrip(df): - ctx = pa.pandas_serialization_context + ctx = pa.pandas_serialization_context() components = 
ctx.serialize(df).to_components() deserialized = ctx.deserialize_components(components) diff --git a/python/pyarrow/tests/test_hdfs.py b/python/pyarrow/tests/test_hdfs.py index 51b6ba25bd657..b62458cd73689 100644 --- a/python/pyarrow/tests/test_hdfs.py +++ b/python/pyarrow/tests/test_hdfs.py @@ -18,6 +18,7 @@ from io import BytesIO from os.path import join as pjoin import os +import pickle import random import unittest @@ -36,7 +37,7 @@ def hdfs_test_client(driver='libhdfs'): host = os.environ.get('ARROW_HDFS_TEST_HOST', 'localhost') - user = os.environ['ARROW_HDFS_TEST_USER'] + user = os.environ.get('ARROW_HDFS_TEST_USER', None) try: port = int(os.environ.get('ARROW_HDFS_TEST_PORT', 20500)) except ValueError: @@ -72,6 +73,22 @@ def tearDownClass(cls): cls.hdfs.delete(cls.tmp_path, recursive=True) cls.hdfs.close() + def test_unknown_driver(self): + with pytest.raises(ValueError): + hdfs_test_client(driver="not_a_driver_name") + + def test_pickle(self): + s = pickle.dumps(self.hdfs) + h2 = pickle.loads(s) + assert h2.is_open + assert h2.host == self.hdfs.host + assert h2.port == self.hdfs.port + assert h2.user == self.hdfs.user + assert h2.kerb_ticket == self.hdfs.kerb_ticket + assert h2.driver == self.hdfs.driver + # smoketest unpickled client works + h2.ls(self.tmp_path) + def test_cat(self): path = pjoin(self.tmp_path, 'cat-test') @@ -299,7 +316,7 @@ class TestLibHdfs(HdfsTestCases, unittest.TestCase): @classmethod def check_driver(cls): if not pa.have_libhdfs(): - pytest.fail('No libhdfs available on system') + pytest.skip('No libhdfs available on system') def test_orphaned_file(self): hdfs = hdfs_test_client() @@ -318,4 +335,4 @@ class TestLibHdfs3(HdfsTestCases, unittest.TestCase): @classmethod def check_driver(cls): if not pa.have_libhdfs3(): - pytest.fail('No libhdfs3 available on system') + pytest.skip('No libhdfs3 available on system') diff --git a/python/pyarrow/tests/test_io.py b/python/pyarrow/tests/test_io.py index e60dd35de66fe..da26b101db260 100644 --- a/python/pyarrow/tests/test_io.py +++ b/python/pyarrow/tests/test_io.py @@ -15,7 +15,7 @@ # specific language governing permissions and limitations # under the License. 
-from io import BytesIO +from io import BytesIO, TextIOWrapper import gc import os import pytest @@ -257,7 +257,7 @@ def test_inmemory_write_after_closed(): f.write(b'ok') f.get_result() - with pytest.raises(IOError): + with pytest.raises(ValueError): f.write(b'not ok') @@ -482,24 +482,106 @@ def test_native_file_modes(tmpdir): with pa.OSFile(path, mode='r') as f: assert f.mode == 'rb' + assert f.readable() + assert not f.writable() + assert f.seekable() with pa.OSFile(path, mode='rb') as f: assert f.mode == 'rb' + assert f.readable() + assert not f.writable() + assert f.seekable() with pa.OSFile(path, mode='w') as f: assert f.mode == 'wb' + assert not f.readable() + assert f.writable() + assert not f.seekable() with pa.OSFile(path, mode='wb') as f: assert f.mode == 'wb' + assert not f.readable() + assert f.writable() + assert not f.seekable() with open(path, 'wb') as f: f.write(b'foooo') with pa.memory_map(path, 'r') as f: assert f.mode == 'rb' + assert f.readable() + assert not f.writable() + assert f.seekable() with pa.memory_map(path, 'r+') as f: assert f.mode == 'rb+' + assert f.readable() + assert f.writable() + assert f.seekable() with pa.memory_map(path, 'r+b') as f: assert f.mode == 'rb+' + assert f.readable() + assert f.writable() + assert f.seekable() + + +def test_native_file_raises_ValueError_after_close(tmpdir): + path = os.path.join(str(tmpdir), guid()) + with open(path, 'wb') as f: + f.write(b'foooo') + + with pa.OSFile(path, mode='rb') as os_file: + assert not os_file.closed + assert os_file.closed + + with pa.memory_map(path, mode='rb') as mmap_file: + assert not mmap_file.closed + assert mmap_file.closed + + files = [os_file, + mmap_file] + + methods = [('tell', ()), + ('seek', (0,)), + ('size', ()), + ('flush', ()), + ('readable', ()), + ('writable', ()), + ('seekable', ())] + + for f in files: + for method, args in methods: + with pytest.raises(ValueError): + getattr(f, method)(*args) + + +def test_native_file_TextIOWrapper(tmpdir): + data = (u'foooo\n' + u'barrr\n' + u'bazzz\n') + + path = os.path.join(str(tmpdir), guid()) + with open(path, 'wb') as f: + f.write(data.encode('utf-8')) + + with TextIOWrapper(pa.OSFile(path, mode='rb')) as fil: + assert fil.readable() + res = fil.read() + assert res == data + assert fil.closed + + with TextIOWrapper(pa.OSFile(path, mode='rb')) as fil: + # Iteration works + lines = list(fil) + assert ''.join(lines) == data + + # Writing + path2 = os.path.join(str(tmpdir), guid()) + with TextIOWrapper(pa.OSFile(path2, mode='wb')) as fil: + assert fil.writable() + fil.write(data) + + with TextIOWrapper(pa.OSFile(path2, mode='rb')) as fil: + res = fil.read() + assert res == data diff --git a/python/pyarrow/tests/test_parquet.py b/python/pyarrow/tests/test_parquet.py index c2bb31c9bcf51..7c2edb378df61 100644 --- a/python/pyarrow/tests/test_parquet.py +++ b/python/pyarrow/tests/test_parquet.py @@ -748,6 +748,28 @@ def test_sanitized_spark_field_names(): assert result.schema[0].name == expected_name +def _roundtrip_pandas_dataframe(df, write_kwargs): + table = pa.Table.from_pandas(df) + + buf = io.BytesIO() + _write_table(table, buf, **write_kwargs) + + buf.seek(0) + table1 = _read_table(buf) + return table1.to_pandas() + + +@parquet +def test_spark_flavor_preserves_pandas_metadata(): + df = _test_dataframe(size=100) + df.index = np.arange(0, 10 * len(df), 10) + df.index.name = 'foo' + + result = _roundtrip_pandas_dataframe(df, {'version': '2.0', + 'flavor': 'spark'}) + tm.assert_frame_equal(result, df) + + @parquet def 
test_fixed_size_binary(): t0 = pa.binary(10) diff --git a/python/pyarrow/tests/test_schema.py b/python/pyarrow/tests/test_schema.py index dbca139e20570..90efe3f7e950a 100644 --- a/python/pyarrow/tests/test_schema.py +++ b/python/pyarrow/tests/test_schema.py @@ -154,8 +154,21 @@ def test_time_types(): pa.time64('s') -def test_type_from_numpy_dtype_timestamps(): +def test_from_numpy_dtype(): cases = [ + (np.dtype('bool'), pa.bool_()), + (np.dtype('int8'), pa.int8()), + (np.dtype('int16'), pa.int16()), + (np.dtype('int32'), pa.int32()), + (np.dtype('int64'), pa.int64()), + (np.dtype('uint8'), pa.uint8()), + (np.dtype('uint16'), pa.uint16()), + (np.dtype('uint32'), pa.uint32()), + (np.dtype('float16'), pa.float16()), + (np.dtype('float32'), pa.float32()), + (np.dtype('float64'), pa.float64()), + (np.dtype('U'), pa.string()), + (np.dtype('S'), pa.binary()), (np.dtype('datetime64[s]'), pa.timestamp('s')), (np.dtype('datetime64[ms]'), pa.timestamp('ms')), (np.dtype('datetime64[us]'), pa.timestamp('us')), @@ -166,6 +179,18 @@ def test_type_from_numpy_dtype_timestamps(): result = pa.from_numpy_dtype(dt) assert result == pt + # Things convertible to numpy dtypes work + assert pa.from_numpy_dtype('U') == pa.string() + assert pa.from_numpy_dtype(np.unicode) == pa.string() + assert pa.from_numpy_dtype('int32') == pa.int32() + assert pa.from_numpy_dtype(bool) == pa.bool_() + + with pytest.raises(NotImplementedError): + pa.from_numpy_dtype(np.dtype('O')) + + with pytest.raises(TypeError): + pa.from_numpy_dtype('not_convertible_to_dtype') + def test_field(): t = pa.string() diff --git a/python/pyarrow/tests/test_serialization.py b/python/pyarrow/tests/test_serialization.py index 6116556386b1a..7a420106f9fb6 100644 --- a/python/pyarrow/tests/test_serialization.py +++ b/python/pyarrow/tests/test_serialization.py @@ -190,8 +190,7 @@ class CustomError(Exception): def make_serialization_context(): - - context = pa._default_serialization_context + context = pa.default_serialization_context() context.register_type(Foo, "Foo") context.register_type(Bar, "Bar") @@ -207,29 +206,35 @@ def make_serialization_context(): return context -serialization_context = make_serialization_context() +global_serialization_context = make_serialization_context() + +def serialization_roundtrip(value, scratch_buffer, + context=global_serialization_context): + writer = pa.FixedSizeBufferWriter(scratch_buffer) + pa.serialize_to(value, writer, context=context) -def serialization_roundtrip(value, f, ctx=serialization_context): - f.seek(0) - pa.serialize_to(value, f, ctx) - f.seek(0) - result = pa.deserialize_from(f, None, ctx) + reader = pa.BufferReader(scratch_buffer) + result = pa.deserialize_from(reader, None, context=context) assert_equal(value, result) - _check_component_roundtrip(value) + _check_component_roundtrip(value, context=context) -def _check_component_roundtrip(value): +def _check_component_roundtrip(value, context=global_serialization_context): # Test to/from components - serialized = pa.serialize(value) + serialized = pa.serialize(value, context=context) components = serialized.to_components() from_comp = pa.SerializedPyObject.from_components(components) - recons = from_comp.deserialize() + recons = from_comp.deserialize(context=context) assert_equal(value, recons) @pytest.yield_fixture(scope='session') +def large_buffer(size=100*1024*1024): + return pa.allocate_buffer(size) + + def large_memory_map(tmpdir_factory, size=100*1024*1024): path = (tmpdir_factory.mktemp('data') .join('pyarrow-serialization-tmp-file').strpath) 
@@ -243,11 +248,11 @@ def large_memory_map(tmpdir_factory, size=100*1024*1024): return path -def test_primitive_serialization(large_memory_map): - with pa.memory_map(large_memory_map, mode="r+") as mmap: - for obj in PRIMITIVE_OBJECTS: - serialization_roundtrip(obj, mmap) - serialization_roundtrip(obj, mmap, pa.pandas_serialization_context) +def test_primitive_serialization(large_buffer): + for obj in PRIMITIVE_OBJECTS: + serialization_roundtrip(obj, large_buffer) + serialization_roundtrip(obj, large_buffer, + pa.pandas_serialization_context()) def test_serialize_to_buffer(): @@ -258,34 +263,31 @@ def test_serialize_to_buffer(): assert_equal(value, result) -def test_complex_serialization(large_memory_map): - with pa.memory_map(large_memory_map, mode="r+") as mmap: - for obj in COMPLEX_OBJECTS: - serialization_roundtrip(obj, mmap) +def test_complex_serialization(large_buffer): + for obj in COMPLEX_OBJECTS: + serialization_roundtrip(obj, large_buffer) -def test_custom_serialization(large_memory_map): - with pa.memory_map(large_memory_map, mode="r+") as mmap: - for obj in CUSTOM_OBJECTS: - serialization_roundtrip(obj, mmap) +def test_custom_serialization(large_buffer): + for obj in CUSTOM_OBJECTS: + serialization_roundtrip(obj, large_buffer) -def test_default_dict_serialization(large_memory_map): +def test_default_dict_serialization(large_buffer): pytest.importorskip("cloudpickle") - with pa.memory_map(large_memory_map, mode="r+") as mmap: - obj = defaultdict(lambda: 0, [("hello", 1), ("world", 2)]) - serialization_roundtrip(obj, mmap) + + obj = defaultdict(lambda: 0, [("hello", 1), ("world", 2)]) + serialization_roundtrip(obj, large_buffer) -def test_numpy_serialization(large_memory_map): - with pa.memory_map(large_memory_map, mode="r+") as mmap: - for t in ["bool", "int8", "uint8", "int16", "uint16", "int32", - "uint32", "float16", "float32", "float64"]: - obj = np.random.randint(0, 10, size=(100, 100)).astype(t) - serialization_roundtrip(obj, mmap) +def test_numpy_serialization(large_buffer): + for t in ["bool", "int8", "uint8", "int16", "uint16", "int32", + "uint32", "float16", "float32", "float64"]: + obj = np.random.randint(0, 10, size=(100, 100)).astype(t) + serialization_roundtrip(obj, large_buffer) -def test_datetime_serialization(large_memory_map): +def test_datetime_serialization(large_buffer): data = [ # Principia Mathematica published datetime.datetime(year=1687, month=7, day=5), @@ -309,32 +311,35 @@ def test_datetime_serialization(large_memory_map): datetime.datetime(year=1970, month=1, day=3, hour=4, minute=0, second=0) ] - with pa.memory_map(large_memory_map, mode="r+") as mmap: - for d in data: - serialization_roundtrip(d, mmap) + for d in data: + serialization_roundtrip(d, large_buffer) -def test_torch_serialization(large_memory_map): +def test_torch_serialization(large_buffer): pytest.importorskip("torch") import torch - with pa.memory_map(large_memory_map, mode="r+") as mmap: - # These are the only types that are supported for the - # PyTorch to NumPy conversion - for t in ["float32", "float64", - "uint8", "int16", "int32", "int64"]: - obj = torch.from_numpy(np.random.randn(1000).astype(t)) - serialization_roundtrip(obj, mmap) - - -def test_numpy_immutable(large_memory_map): - with pa.memory_map(large_memory_map, mode="r+") as mmap: - obj = np.zeros([10]) - mmap.seek(0) - pa.serialize_to(obj, mmap, serialization_context) - mmap.seek(0) - result = pa.deserialize_from(mmap, None, serialization_context) - with pytest.raises(ValueError): - result[0] = 1.0 + + 
serialization_context = pa.default_serialization_context() + pa.register_torch_serialization_handlers(serialization_context) + # These are the only types that are supported for the + # PyTorch to NumPy conversion + for t in ["float32", "float64", + "uint8", "int16", "int32", "int64"]: + obj = torch.from_numpy(np.random.randn(1000).astype(t)) + serialization_roundtrip(obj, large_buffer, + context=serialization_context) + + +def test_numpy_immutable(large_buffer): + obj = np.zeros([10]) + + writer = pa.FixedSizeBufferWriter(large_buffer) + pa.serialize_to(obj, writer, global_serialization_context) + + reader = pa.BufferReader(large_buffer) + result = pa.deserialize_from(reader, None, global_serialization_context) + with pytest.raises(ValueError): + result[0] = 1.0 # see https://issues.apache.org/jira/browse/ARROW-1695 @@ -350,12 +355,12 @@ def serialize_dummy_class(obj): def deserialize_dummy_class(serialized_obj): return serialized_obj - pa._default_serialization_context.register_type( - DummyClass, "DummyClass", - custom_serializer=serialize_dummy_class, - custom_deserializer=deserialize_dummy_class) + context = pa.default_serialization_context() + context.register_type(DummyClass, "DummyClass", + custom_serializer=serialize_dummy_class, + custom_deserializer=deserialize_dummy_class) - pa.serialize(DummyClass()) + pa.serialize(DummyClass(), context=context) def test_buffer_serialization(): @@ -369,13 +374,14 @@ def serialize_buffer_class(obj): def deserialize_buffer_class(serialized_obj): return serialized_obj - pa._default_serialization_context.register_type( + context = pa.default_serialization_context() + context.register_type( BufferClass, "BufferClass", custom_serializer=serialize_buffer_class, custom_deserializer=deserialize_buffer_class) - b = pa.serialize(BufferClass()).to_buffer() - assert pa.deserialize(b).to_pybytes() == b"hello" + b = pa.serialize(BufferClass(), context=context).to_buffer() + assert pa.deserialize(b, context=context).to_pybytes() == b"hello" @pytest.mark.skip(reason="extensive memory requirements") @@ -484,15 +490,16 @@ def test_serialize_subclasses(): # with register_type will result in faster and more memory # efficient serialization. - serialization_context.register_type( + context = pa.default_serialization_context() + context.register_type( Serializable, "Serializable", custom_serializer=serialize_serializable, custom_deserializer=deserialize_serializable) a = SerializableClass() - serialized = pa.serialize(a) + serialized = pa.serialize(a, context=context) - deserialized = serialized.deserialize() + deserialized = serialized.deserialize(context=context) assert type(deserialized).__name__ == SerializableClass.__name__ assert deserialized.value == 3 @@ -554,4 +561,43 @@ def test_deserialize_buffer_in_different_process(): dir_path = os.path.dirname(os.path.realpath(__file__)) python_file = os.path.join(dir_path, 'deserialize_buffer.py') - subprocess.check_call(['python', python_file, f.name]) + subprocess.check_call([sys.executable, python_file, f.name]) + + +def test_set_pickle(): + # Use a custom type to trigger pickling. + class Foo(object): + pass + + context = pa.SerializationContext() + context.register_type(Foo, 'Foo', pickle=True) + + test_object = Foo() + + # Define a custom serializer and deserializer to use in place of pickle. + + def dumps1(obj): + return b'custom' + + def loads1(serialized_obj): + return serialized_obj + b' serialization 1' + + # Test that setting a custom pickler changes the behavior. 
+ context.set_pickle(dumps1, loads1) + serialized = pa.serialize(test_object, context=context).to_buffer() + deserialized = pa.deserialize(serialized.to_pybytes(), context=context) + assert deserialized == b'custom serialization 1' + + # Define another custom serializer and deserializer. + + def dumps2(obj): + return b'custom' + + def loads2(serialized_obj): + return serialized_obj + b' serialization 2' + + # Test that setting another custom pickler changes the behavior again. + context.set_pickle(dumps2, loads2) + serialized = pa.serialize(test_object, context=context).to_buffer() + deserialized = pa.deserialize(serialized.to_pybytes(), context=context) + assert deserialized == b'custom serialization 2' diff --git a/python/pyarrow/tests/test_types.py b/python/pyarrow/tests/test_types.py index 68dc499cf48b4..ad683e9a2ea00 100644 --- a/python/pyarrow/tests/test_types.py +++ b/python/pyarrow/tests/test_types.py @@ -184,3 +184,13 @@ def test_types_hashable(): ]) def test_exact_primitive_types(t, check_func): assert check_func(t) + + +def test_fixed_size_binary_byte_width(): + ty = pa.binary(5) + assert ty.byte_width == 5 + + +def test_decimal_byte_width(): + ty = pa.decimal128(19, 4) + assert ty.byte_width == 16 diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi index 1563b57855cd9..849a0e016a60d 100644 --- a/python/pyarrow/types.pxi +++ b/python/pyarrow/types.pxi @@ -293,7 +293,7 @@ cdef class FixedSizeBinaryType(DataType): cdef class Decimal128Type(FixedSizeBinaryType): cdef void init(self, const shared_ptr[CDataType]& type): - DataType.init(self, type) + FixedSizeBinaryType.init(self, type) self.decimal128_type = type.get() def __getstate__(self): @@ -1207,6 +1207,7 @@ def from_numpy_dtype(object dtype): Convert NumPy dtype to pyarrow.DataType """ cdef shared_ptr[CDataType] c_type + dtype = np.dtype(dtype) with nogil: check_status(NumPyDtypeToArrow(dtype, &c_type)) diff --git a/python/setup.py b/python/setup.py index 3d3831dc048c6..cfc771fe870ab 100644 --- a/python/setup.py +++ b/python/setup.py @@ -17,10 +17,12 @@ # specific language governing permissions and limitations # under the License. 
+import contextlib import glob import os import os.path as osp import re +import shlex import shutil import sys @@ -46,6 +48,16 @@ setup_dir = os.path.abspath(os.path.dirname(__file__)) +@contextlib.contextmanager +def changed_dir(dirname): + oldcwd = os.getcwd() + os.chdir(dirname) + try: + yield + finally: + os.chdir(oldcwd) + + class clean(_clean): def run(self): @@ -58,6 +70,7 @@ def run(self): class build_ext(_build_ext): + _found_names = () def build_extensions(self): numpy_incl = pkg_resources.resource_filename('numpy', 'core/include') @@ -128,163 +141,160 @@ def _run_cmake(self): # The staging directory for the module being built build_temp = pjoin(os.getcwd(), self.build_temp) build_lib = os.path.join(os.getcwd(), self.build_lib) - - # Change to the build directory saved_cwd = os.getcwd() + if not os.path.isdir(self.build_temp): self.mkpath(self.build_temp) - os.chdir(self.build_temp) - - # Detect if we built elsewhere - if os.path.isfile('CMakeCache.txt'): - cachefile = open('CMakeCache.txt', 'r') - cachedir = re.search('CMAKE_CACHEFILE_DIR:INTERNAL=(.*)', - cachefile.read()).group(1) - cachefile.close() - if (cachedir != build_temp): - return - - static_lib_option = '' - - cmake_options = [ - '-DPYTHON_EXECUTABLE=%s' % sys.executable, - static_lib_option, - ] - if self.with_parquet: - cmake_options.append('-DPYARROW_BUILD_PARQUET=on') - if self.with_static_parquet: - cmake_options.append('-DPYARROW_PARQUET_USE_SHARED=off') - if not self.with_static_boost: - cmake_options.append('-DPYARROW_BOOST_USE_SHARED=on') - - if self.with_plasma: - cmake_options.append('-DPYARROW_BUILD_PLASMA=on') - - if self.with_orc: - cmake_options.append('-DPYARROW_BUILD_ORC=on') - - if len(self.cmake_cxxflags) > 0: - cmake_options.append('-DPYARROW_CXXFLAGS="{0}"' - .format(self.cmake_cxxflags)) - - if self.bundle_arrow_cpp: - cmake_options.append('-DPYARROW_BUNDLE_ARROW_CPP=ON') - # ARROW-1090: work around CMake rough edges - if 'ARROW_HOME' in os.environ and sys.platform != 'win32': - pkg_config = pjoin(os.environ['ARROW_HOME'], 'lib', - 'pkgconfig') - os.environ['PKG_CONFIG_PATH'] = pkg_config - del os.environ['ARROW_HOME'] - - cmake_options.append('-DCMAKE_BUILD_TYPE={0}' - .format(self.build_type.lower())) - - if sys.platform != 'win32': - cmake_command = (['cmake', self.extra_cmake_args] + - cmake_options + [source]) - - print("-- Runnning cmake for pyarrow") - self.spawn(cmake_command) - print("-- Finished cmake for pyarrow") - args = ['make'] - if os.environ.get('PYARROW_BUILD_VERBOSE', '0') == '1': - args.append('VERBOSE=1') - - if 'PYARROW_PARALLEL' in os.environ: - args.append('-j{0}'.format(os.environ['PYARROW_PARALLEL'])) - print("-- Running cmake --build for pyarrow") - self.spawn(args) - print("-- Finished cmake --build for pyarrow") - else: - import shlex - cmake_generator = 'Visual Studio 14 2015 Win64' - if not is_64_bit: - raise RuntimeError('Not supported on 32-bit Windows') + # Change to the build directory + with changed_dir(self.build_temp): + # Detect if we built elsewhere + if os.path.isfile('CMakeCache.txt'): + cachefile = open('CMakeCache.txt', 'r') + cachedir = re.search('CMAKE_CACHEFILE_DIR:INTERNAL=(.*)', + cachefile.read()).group(1) + cachefile.close() + if (cachedir != build_temp): + return + + static_lib_option = '' + + cmake_options = [ + '-DPYTHON_EXECUTABLE=%s' % sys.executable, + static_lib_option, + ] + + if self.with_parquet: + cmake_options.append('-DPYARROW_BUILD_PARQUET=on') + if self.with_static_parquet: + 
cmake_options.append('-DPYARROW_PARQUET_USE_SHARED=off') + if not self.with_static_boost: + cmake_options.append('-DPYARROW_BOOST_USE_SHARED=on') + + if self.with_plasma: + cmake_options.append('-DPYARROW_BUILD_PLASMA=on') + + if self.with_orc: + cmake_options.append('-DPYARROW_BUILD_ORC=on') + + if len(self.cmake_cxxflags) > 0: + cmake_options.append('-DPYARROW_CXXFLAGS="{0}"' + .format(self.cmake_cxxflags)) + + if self.bundle_arrow_cpp: + cmake_options.append('-DPYARROW_BUNDLE_ARROW_CPP=ON') + # ARROW-1090: work around CMake rough edges + if 'ARROW_HOME' in os.environ and sys.platform != 'win32': + pkg_config = pjoin(os.environ['ARROW_HOME'], 'lib', + 'pkgconfig') + os.environ['PKG_CONFIG_PATH'] = pkg_config + del os.environ['ARROW_HOME'] + + cmake_options.append('-DCMAKE_BUILD_TYPE={0}' + .format(self.build_type.lower())) - # Generate the build files extra_cmake_args = shlex.split(self.extra_cmake_args) - cmake_command = (['cmake'] + extra_cmake_args + - cmake_options + - [source, '-G', cmake_generator]) - if "-G" in self.extra_cmake_args: - cmake_command = cmake_command[:-2] - - print("-- Runnning cmake for pyarrow") - self.spawn(cmake_command) - print("-- Finished cmake for pyarrow") - # Do the build - print("-- Running cmake --build for pyarrow") - self.spawn(['cmake', '--build', '.', '--config', self.build_type]) - print("-- Finished cmake --build for pyarrow") - - if self.inplace: - # a bit hacky - build_lib = saved_cwd - - # Move the libraries to the place expected by the Python - # build - - try: - os.makedirs(pjoin(build_lib, 'pyarrow')) - except OSError: - pass + if sys.platform != 'win32': + cmake_command = (['cmake'] + extra_cmake_args + + cmake_options + [source]) + + print("-- Runnning cmake for pyarrow") + self.spawn(cmake_command) + print("-- Finished cmake for pyarrow") + args = ['make'] + if os.environ.get('PYARROW_BUILD_VERBOSE', '0') == '1': + args.append('VERBOSE=1') + + if 'PYARROW_PARALLEL' in os.environ: + args.append('-j{0}'.format(os.environ['PYARROW_PARALLEL'])) + print("-- Running cmake --build for pyarrow") + self.spawn(args) + print("-- Finished cmake --build for pyarrow") + else: + cmake_generator = 'Visual Studio 14 2015 Win64' + if not is_64_bit: + raise RuntimeError('Not supported on 32-bit Windows') + + # Generate the build files + cmake_command = (['cmake'] + extra_cmake_args + + cmake_options + + [source, '-G', cmake_generator]) + if "-G" in self.extra_cmake_args: + cmake_command = cmake_command[:-2] + + print("-- Runnning cmake for pyarrow") + self.spawn(cmake_command) + print("-- Finished cmake for pyarrow") + # Do the build + print("-- Running cmake --build for pyarrow") + self.spawn(['cmake', '--build', '.', '--config', self.build_type]) + print("-- Finished cmake --build for pyarrow") + + if self.inplace: + # a bit hacky + build_lib = saved_cwd + + # Move the libraries to the place expected by the Python + # build - if sys.platform == 'win32': - build_prefix = '' - else: - build_prefix = self.build_type + try: + os.makedirs(pjoin(build_lib, 'pyarrow')) + except OSError: + pass - if self.bundle_arrow_cpp: - print(pjoin(build_lib, 'pyarrow')) - move_shared_libs(build_prefix, build_lib, "arrow") - move_shared_libs(build_prefix, build_lib, "arrow_python") + if sys.platform == 'win32': + build_prefix = '' + else: + build_prefix = self.build_type + + if self.bundle_arrow_cpp: + print(pjoin(build_lib, 'pyarrow')) + move_shared_libs(build_prefix, build_lib, "arrow") + move_shared_libs(build_prefix, build_lib, "arrow_python") + if self.with_plasma: + 
move_shared_libs(build_prefix, build_lib, "plasma") + if self.with_parquet and not self.with_static_parquet: + move_shared_libs(build_prefix, build_lib, "parquet") + + print('Bundling includes: ' + pjoin(build_prefix, 'include')) + if os.path.exists(pjoin(build_lib, 'pyarrow', 'include')): + shutil.rmtree(pjoin(build_lib, 'pyarrow', 'include')) + shutil.move(pjoin(build_prefix, 'include'), + pjoin(build_lib, 'pyarrow')) + + # Move the built C-extension to the place expected by the Python build + self._found_names = [] + for name in self.CYTHON_MODULE_NAMES: + built_path = self.get_ext_built(name) + if not os.path.exists(built_path): + print(built_path) + if self._failure_permitted(name): + print('Cython module {0} failure permitted'.format(name)) + continue + raise RuntimeError('pyarrow C-extension failed to build:', + os.path.abspath(built_path)) + + ext_path = pjoin(build_lib, self._get_cmake_ext_path(name)) + if os.path.exists(ext_path): + os.remove(ext_path) + self.mkpath(os.path.dirname(ext_path)) + print('Moving built C-extension', built_path, + 'to build path', ext_path) + shutil.move(self.get_ext_built(name), ext_path) + self._found_names.append(name) + + if os.path.exists(self.get_ext_built_api_header(name)): + shutil.move(self.get_ext_built_api_header(name), + pjoin(os.path.dirname(ext_path), name + '_api.h')) + + # Move the plasma store if self.with_plasma: - move_shared_libs(build_prefix, build_lib, "plasma") - if self.with_parquet and not self.with_static_parquet: - move_shared_libs(build_prefix, build_lib, "parquet") - - print('Bundling includes: ' + pjoin(build_prefix, 'include')) - if os.path.exists(pjoin(build_lib, 'pyarrow', 'include')): - shutil.rmtree(pjoin(build_lib, 'pyarrow', 'include')) - shutil.move(pjoin(build_prefix, 'include'), - pjoin(build_lib, 'pyarrow')) - - # Move the built C-extension to the place expected by the Python build - self._found_names = [] - for name in self.CYTHON_MODULE_NAMES: - built_path = self.get_ext_built(name) - if not os.path.exists(built_path): - print(built_path) - if self._failure_permitted(name): - print('Cython module {0} failure permitted'.format(name)) - continue - raise RuntimeError('pyarrow C-extension failed to build:', - os.path.abspath(built_path)) - - ext_path = pjoin(build_lib, self._get_cmake_ext_path(name)) - if os.path.exists(ext_path): - os.remove(ext_path) - self.mkpath(os.path.dirname(ext_path)) - print('Moving built C-extension', built_path, - 'to build path', ext_path) - shutil.move(self.get_ext_built(name), ext_path) - self._found_names.append(name) - - if os.path.exists(self.get_ext_built_api_header(name)): - shutil.move(self.get_ext_built_api_header(name), - pjoin(os.path.dirname(ext_path), name + '_api.h')) - - # Move the plasma store - if self.with_plasma: - build_py = self.get_finalized_command('build_py') - source = os.path.join(self.build_type, "plasma_store") - target = os.path.join(build_lib, - build_py.get_package_dir('pyarrow'), - "plasma_store") - shutil.move(source, target) - - os.chdir(saved_cwd) + build_py = self.get_finalized_command('build_py') + source = os.path.join(self.build_type, "plasma_store") + target = os.path.join(build_lib, + build_py.get_package_dir('pyarrow'), + "plasma_store") + shutil.move(source, target) def _failure_permitted(self, name): if name == '_parquet' and not self.with_parquet: diff --git a/site/_data/versions.yml b/site/_data/versions.yml new file mode 100644 index 0000000000000..0d04183868dcf --- /dev/null +++ b/site/_data/versions.yml @@ -0,0 +1,29 @@ +# Licensed to 
the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to you under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Database of Apache Arrow release versions (WIP)
+# Blogs and other pages use this data
+#
+current:
+  number: '0.8.0'
+  date: '18 December 2017'
+  git-tag: '1d689e5'
+  github-tag-link: 'https://github.com/apache/arrow/releases/tag/apache-arrow-0.8.0'
+  release-notes: 'http://arrow.apache.org/release/0.8.0.html'
+  mirrors: 'https://www.apache.org/dyn/closer.cgi/arrow/arrow-0.8.0/'
+  mirrors-tar: 'https://www.apache.org/dyn/closer.cgi/arrow/arrow-0.8.0/apache-arrow-0.8.0.tar.gz'
+  java-artifacts: 'http://search.maven.org/#search%7Cga%7C1%7Cg%3A%22org.apache.arrow%22%20AND%20v%3A%220.8.0%22'
+  asc: 'https://www.apache.org/dist/arrow/arrow-0.8.0/apache-arrow-0.8.0.tar.gz.asc'
+  sha512: 'https://www.apache.org/dist/arrow/arrow-0.8.0/apache-arrow-0.8.0.tar.gz.sha512'
\ No newline at end of file
diff --git a/site/img/copy.png b/site/img/copy.png
index a1e04999eb3fd..55ff71ece1e59 100644
Binary files a/site/img/copy.png and b/site/img/copy.png differ
diff --git a/site/img/copy2.png b/site/img/copy2.png
deleted file mode 100644
index 7869daddefe9f..0000000000000
Binary files a/site/img/copy2.png and /dev/null differ
diff --git a/site/img/shared.png b/site/img/shared.png
index 7869daddefe9f..b079ad0c6b4e5 100644
Binary files a/site/img/shared.png and b/site/img/shared.png differ
diff --git a/site/img/shared2.png b/site/img/shared2.png
deleted file mode 100644
index a1e04999eb3fd..0000000000000
Binary files a/site/img/shared2.png and /dev/null differ
diff --git a/site/index.html b/site/index.html
index 87995cbabed48..ec80075c59200 100644
--- a/site/index.html
+++ b/site/index.html
@@ -1,74 +1,78 @@
 ---
 layout: default
 ---
[The body of this site/index.html hunk was garbled during extraction: the HTML markup was stripped, leaving only bare `+`/`-` markers and text fragments. The recoverable changes in the hunk are:
 * The header button "Install (0.8.0 Release - December 18, 2017)" becomes "Install ({{site.data.versions['current'].number}} Release - {{site.data.versions['current'].date}})", reading the release number and date from the new site/_data/versions.yml data file.
 * The tagline "Powering Columnar In-Memory Analytics" is replaced with "A cross-language development platform for in-memory data", and a new introductory paragraph is added: "Apache Arrow is a cross-language development platform for in-memory data. It specifies a standardized language-independent columnar memory format for flat and hierarchical data, organized for efficient analytic operations on modern hardware. It also provides computational libraries and zero-copy streaming messaging and interprocess communication. Languages currently supported include C, C++, Java, JavaScript, Python, and Ruby."
 * The "Fast", "Flexible", and "Standard" blurbs and the "Performance Advantage of Columnar In-Memory" / "Advantages of a Common Data Layer" sections keep their wording but are reflowed and restructured within the page markup.]
diff --git a/site/install.md b/site/install.md
index ec30e0469cdc1..f795299676eb5 100644
--- a/site/install.md
+++ b/site/install.md
@@ -20,9 +20,9 @@ limitations under the License.
 {% endcomment %}
 -->
 
-## Current Version: 0.8.0
+## Current Version: {{site.data.versions['current'].number}}
 
-### Released: 18 December 2017
+### Released: {{site.data.versions['current'].date}}
 
 See the [release notes][10] for more about what's new.
 
@@ -30,7 +30,7 @@ See the [release notes][10] for more about what's new.
 
 * **Source Release**: [apache-arrow-0.8.0.tar.gz][6]
 * **Verification**: [sha512][3], [asc][7] ([verification instructions][12])
-* [Git tag 1d689e5][2]
+* [Git tag {{site.data.versions['current'].git-tag}}][2]
 * [PGP keys for release signatures][11]
 
 ### Java Packages
@@ -145,15 +145,15 @@ These repositories are managed at
 [red-data-tools/arrow-packages][9]. If you have any feedback, please
 send it to the project instead of Apache Arrow project.
 
-[1]: https://www.apache.org/dyn/closer.cgi/arrow/arrow-0.8.0/
-[2]: https://github.com/apache/arrow/releases/tag/apache-arrow-0.8.0
-[3]: https://www.apache.org/dist/arrow/arrow-0.8.0/apache-arrow-0.8.0.tar.gz.sha512
-[4]: http://search.maven.org/#search%7Cga%7C1%7Cg%3A%22org.apache.arrow%22%20AND%20v%3A%220.8.0%22
+[1]: {{site.data.versions['current'].mirrors}}
+[2]: {{site.data.versions['current'].github-tag-link}}
+[3]: {{site.data.versions['current'].sha512}}
+[4]: {{site.data.versions['current'].java-artifacts}}
 [5]: http://conda-forge.github.io
-[6]: https://www.apache.org/dyn/closer.cgi/arrow/arrow-0.8.0/apache-arrow-0.8.0.tar.gz
-[7]: https://www.apache.org/dist/arrow/arrow-0.8.0/apache-arrow-0.8.0.tar.gz.asc
+[6]: {{site.data.versions['current'].mirrors-tar}}
+[7]: {{site.data.versions['current'].asc}}
 [8]: https://github.com/red-data-tools/parquet-glib
 [9]: https://github.com/red-data-tools/arrow-packages
-[10]: http://arrow.apache.org/release/0.8.0.html
+[10]: {{site.data.versions['current'].release-notes}}
 [11]: http://www.apache.org/dist/arrow/KEYS
 [12]: https://www.apache.org/dyn/closer.cgi#verify
\ No newline at end of file
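A note on how the new `site/_data/versions.yml` file is consumed: Jekyll exposes any file under `_data/` as `site.data.<name>`, which is what the Liquid references added to `site/install.md` and `site/index.html` above rely on. The sketch below is illustrative only and is not part of this patch; the page front matter and surrounding sentences are invented, but the data keys match the ones added in versions.yml. It shows how another site page or blog post could reuse the same release data, so that a future release bump only needs to touch versions.yml.

```liquid
---
layout: default
---
{% comment %}
Illustrative sketch only: reads the release data added in site/_data/versions.yml.
{% endcomment %}
The latest release is Apache Arrow {{site.data.versions['current'].number}},
released on {{site.data.versions['current'].date}}.
See the [release notes]({{site.data.versions['current'].release-notes}}), or
download the [source tarball]({{site.data.versions['current'].mirrors-tar}})
([asc]({{site.data.versions['current'].asc}}),
[sha512]({{site.data.versions['current'].sha512}})).
```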