Add support to run specific unittests and/or doctests in python/run-tests script
HyukjinKwon committed Dec 3, 2018
1 parent 676bbb2 commit 44c622b
Showing 2 changed files with 56 additions and 32 deletions.
2 changes: 0 additions & 2 deletions python/run-tests-with-coverage
@@ -50,8 +50,6 @@ export SPARK_CONF_DIR="$COVERAGE_DIR/conf"
 # This environment variable enables the coverage.
 export COVERAGE_PROCESS_START="$FWDIR/.coveragerc"
 
-# If you'd like to run a specific unittest class, you could do such as
-# SPARK_TESTING=1 ../bin/pyspark pyspark.sql.tests VectorizedUDFTests
 ./run-tests "$@"
 
 # Don't run coverage for the coverage command itself
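
The hint removed above is superseded by the new --testnames option introduced in
run-tests.py below. Since this script forwards its arguments via "$@", the same
test class could presumably now be run under coverage with, for example:

    ./run-tests-with-coverage --testnames 'pyspark.sql.tests VectorizedUDFTests'
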
86 changes: 56 additions & 30 deletions python/run-tests.py
@@ -19,7 +19,7 @@
 
 from __future__ import print_function
 import logging
-from optparse import OptionParser
+from optparse import OptionParser, OptionGroup
 import os
 import re
 import shutil
@@ -93,17 +93,18 @@ def run_individual_python_test(target_dir, test_name, pyspark_python):
         "pyspark-shell"
     ]
     env["PYSPARK_SUBMIT_ARGS"] = " ".join(spark_args)
-
-    LOGGER.info("Starting test(%s): %s", pyspark_python, test_name)
+    str_test_name = " ".join(test_name)
+    LOGGER.info("Starting test(%s): %s", pyspark_python, str_test_name)
     start_time = time.time()
     try:
         per_test_output = tempfile.TemporaryFile()
         retcode = subprocess.Popen(
-            [os.path.join(SPARK_HOME, "bin/pyspark"), test_name],
+            (os.path.join(SPARK_HOME, "bin/pyspark"), ) + test_name,
             stderr=per_test_output, stdout=per_test_output, env=env).wait()
         shutil.rmtree(tmp_dir, ignore_errors=True)
     except:
-        LOGGER.exception("Got exception while running %s with %s", test_name, pyspark_python)
+        LOGGER.exception(
+            "Got exception while running %s with %s", str_test_name, pyspark_python)
         # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if
         # this code is invoked from a thread other than the main thread.
         os._exit(1)
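
A minimal sketch of what the tuple change buys: test_name is no longer a single
string but a tuple of argv tokens, so the launcher command is built by plain
tuple concatenation (the path below stands in for a real SPARK_HOME):

    # test_name arrives pre-split into tokens, e.g. from "pyspark.sql.tests FooTests"
    test_name = ("pyspark.sql.tests", "FooTests")
    # Prepending the launcher yields one flat argv tuple for subprocess.Popen
    cmd = ("/path/to/spark/bin/pyspark", ) + test_name
    assert cmd == ("/path/to/spark/bin/pyspark", "pyspark.sql.tests", "FooTests")
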
@@ -124,7 +125,8 @@ def run_individual_python_test(target_dir, test_name, pyspark_python):
         except:
             LOGGER.exception("Got an exception while trying to print failed test output")
         finally:
-            print_red("\nHad test failures in %s with %s; see logs." % (test_name, pyspark_python))
+            print_red("\nHad test failures in %s with %s; see logs." % (
+                str_test_name, pyspark_python))
             # Here, we use os._exit() instead of sys.exit() in order to force Python to exit even if
             # this code is invoked from a thread other than the main thread.
             os._exit(-1)
@@ -140,7 +142,7 @@ def run_individual_python_test(target_dir, test_name, pyspark_python):
                 decoded_lines))
             skipped_counts = len(skipped_tests)
             if skipped_counts > 0:
-                key = (pyspark_python, test_name)
+                key = (pyspark_python, str_test_name)
                 SKIPPED_TESTS[key] = skipped_tests
             per_test_output.close()
         except:
@@ -152,11 +154,11 @@ def run_individual_python_test(target_dir, test_name, pyspark_python):
             os._exit(-1)
         if skipped_counts != 0:
             LOGGER.info(
-                "Finished test(%s): %s (%is) ... %s tests were skipped", pyspark_python, test_name,
-                duration, skipped_counts)
+                "Finished test(%s): %s (%is) ... %s tests were skipped", pyspark_python,
+                str_test_name, duration, skipped_counts)
         else:
             LOGGER.info(
-                "Finished test(%s): %s (%is)", pyspark_python, test_name, duration)
+                "Finished test(%s): %s (%is)", pyspark_python, str_test_name, duration)
 
 
 def get_default_python_executables():
@@ -190,6 +192,20 @@ def parse_opts():
         help="Enable additional debug logging"
     )
 
+    group = OptionGroup(parser, "Developer Options")
+    group.add_option(
+        "--testnames", type="string",
+        default=None,
+        help=(
+            "A comma-separated list of specific modules, classes and functions of doctest "
+            "or unittest to test. "
+            "For example, 'pyspark.sql.foo' to run the module as unittests or doctests, "
+            "'pyspark.sql.tests FooTests' to run the specific class of unittests, "
+            "'pyspark.sql.tests FooTests.test_foo' to run the specific unittest in the class. "
+            "'--modules' option is ignored if they are given.")
+    )
+    parser.add_option_group(group)
+
     (opts, args) = parser.parse_args()
     if args:
         parser.error("Unsupported arguments: %s" % ' '.join(args))
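
For reference, a self-contained sketch of the OptionGroup wiring above,
stripped down to just the new flag (all other options omitted):

    from optparse import OptionParser, OptionGroup

    parser = OptionParser()
    group = OptionGroup(parser, "Developer Options")
    group.add_option("--testnames", type="string", default=None)
    parser.add_option_group(group)
    # The value stays one comma-separated string; main() splits it later.
    (opts, _) = parser.parse_args(["--testnames", "pyspark.sql.foo,pyspark.sql.tests FooTests"])
    print(opts.testnames.split(','))
    # ['pyspark.sql.foo', 'pyspark.sql.tests FooTests']
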
@@ -213,25 +229,31 @@ def _check_coverage(python_exec):
 
 def main():
     opts = parse_opts()
-    if (opts.verbose):
+    if opts.verbose:
         log_level = logging.DEBUG
     else:
         log_level = logging.INFO
+    should_test_modules = opts.testnames is None
     logging.basicConfig(stream=sys.stdout, level=log_level, format="%(message)s")
     LOGGER.info("Running PySpark tests. Output is in %s", LOG_FILE)
     if os.path.exists(LOG_FILE):
         os.remove(LOG_FILE)
     python_execs = opts.python_executables.split(',')
-    modules_to_test = []
-    for module_name in opts.modules.split(','):
-        if module_name in python_modules:
-            modules_to_test.append(python_modules[module_name])
-        else:
-            print("Error: unrecognized module '%s'. Supported modules: %s" %
-                  (module_name, ", ".join(python_modules)))
-            sys.exit(-1)
     LOGGER.info("Will test against the following Python executables: %s", python_execs)
-    LOGGER.info("Will test the following Python modules: %s", [x.name for x in modules_to_test])
+
+    if should_test_modules:
+        modules_to_test = []
+        for module_name in opts.modules.split(','):
+            if module_name in python_modules:
+                modules_to_test.append(python_modules[module_name])
+            else:
+                print("Error: unrecognized module '%s'. Supported modules: %s" %
+                      (module_name, ", ".join(python_modules)))
+                sys.exit(-1)
+        LOGGER.info("Will test the following Python modules: %s", [x.name for x in modules_to_test])
+    else:
+        testnames_to_test = opts.testnames.split(',')
+        LOGGER.info("Will test the following Python tests: %s", testnames_to_test)
 
     task_queue = Queue.PriorityQueue()
     for python_exec in python_execs:
@@ -246,16 +268,20 @@ def main():
         LOGGER.debug("%s python_implementation is %s", python_exec, python_implementation)
         LOGGER.debug("%s version is: %s", python_exec, subprocess_check_output(
             [python_exec, "--version"], stderr=subprocess.STDOUT, universal_newlines=True).strip())
-        for module in modules_to_test:
-            if python_implementation not in module.blacklisted_python_implementations:
-                for test_goal in module.python_test_goals:
-                    heavy_tests = ['pyspark.streaming.tests', 'pyspark.mllib.tests',
-                                   'pyspark.tests', 'pyspark.sql.tests', 'pyspark.ml.tests']
-                    if any(map(lambda prefix: test_goal.startswith(prefix), heavy_tests)):
-                        priority = 0
-                    else:
-                        priority = 100
-                    task_queue.put((priority, (python_exec, test_goal)))
+        if should_test_modules:
+            for module in modules_to_test:
+                if python_implementation not in module.blacklisted_python_implementations:
+                    for test_goal in module.python_test_goals:
+                        heavy_tests = ['pyspark.streaming.tests', 'pyspark.mllib.tests',
+                                       'pyspark.tests', 'pyspark.sql.tests', 'pyspark.ml.tests']
+                        if any(map(lambda prefix: test_goal.startswith(prefix), heavy_tests)):
+                            priority = 0
+                        else:
+                            priority = 100
+                        task_queue.put((priority, (python_exec, (test_goal, ))))
+        else:
+            for test_goal in testnames_to_test:
+                task_queue.put((0, (python_exec, tuple(test_goal.split()))))
 
     # Create the target directory before starting tasks to avoid races.
     target_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), 'target'))
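
Tying the two branches together, a simplified sketch (the executable name is
illustrative) of how a --testnames value becomes task-queue entries: goals are
split on commas, each goal on whitespace, and every --testnames goal is queued
at the top priority of 0:

    testnames = "pyspark.sql.foo,pyspark.sql.tests FooTests.test_foo"
    for test_goal in testnames.split(','):
        # tuple(test_goal.split()) mirrors the argv-token tuple consumed above
        print((0, ("python2.7", tuple(test_goal.split()))))
    # (0, ('python2.7', ('pyspark.sql.foo',)))
    # (0, ('python2.7', ('pyspark.sql.tests', 'FooTests.test_foo')))
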
