From c16c392c6b41b0922e9fc95dc54125ba926bdb13 Mon Sep 17 00:00:00 2001
From: Davies Liu
Date: Mon, 22 Sep 2014 11:58:37 -0700
Subject: [PATCH] put addPyFile in front of sys.path

---
 python/pyspark/context.py | 11 +++++------
 python/pyspark/worker.py  | 11 +++++++++--
 2 files changed, 14 insertions(+), 8 deletions(-)

diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 064a24bff539c..8e7b00469e246 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -171,7 +171,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
 
         SparkFiles._sc = self
         root_dir = SparkFiles.getRootDirectory()
-        sys.path.append(root_dir)
+        sys.path.insert(1, root_dir)
 
         # Deploy any code dependencies specified in the constructor
         self._python_includes = list()
@@ -183,10 +183,9 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
         for path in self._conf.get("spark.submit.pyFiles", "").split(","):
             if path != "":
                 (dirname, filename) = os.path.split(path)
-                self._python_includes.append(filename)
-                sys.path.append(path)
-                if dirname not in sys.path:
-                    sys.path.append(dirname)
+                if filename.lower().endswith("zip") or filename.lower().endswith("egg"):
+                    self._python_includes.append(filename)
+                    sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename))
 
         # Create a temporary directory inside spark.local.dir:
         local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir(self._jsc.sc().conf())
@@ -667,7 +666,7 @@ def addPyFile(self, path):
         if filename.endswith('.zip') or filename.endswith('.ZIP') or filename.endswith('.egg'):
             self._python_includes.append(filename)
             # for tests in local mode
-            sys.path.append(os.path.join(SparkFiles.getRootDirectory(), filename))
+            sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename))
 
     def setCheckpointDir(self, dirName):
         """
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index d6c06e2dbef62..c1f6e3e4a1f40 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -43,6 +43,13 @@ def report_times(outfile, boot, init, finish):
     write_long(1000 * finish, outfile)
 
 
+def add_path(path):
+    # workers can be reused, so do not add the same path multiple times
+    if path not in sys.path:
+        # insert at position 1 so added files shadow system packages
+        sys.path.insert(1, path)
+
+
 def main(infile, outfile):
     try:
         boot_time = time.time()
@@ -61,11 +68,11 @@ def main(infile, outfile):
         SparkFiles._is_running_on_worker = True
 
         # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
-        sys.path.append(spark_files_dir)  # *.py files that were added will be copied here
+        add_path(spark_files_dir)  # *.py files that were added will be copied here
         num_python_includes = read_int(infile)
         for _ in range(num_python_includes):
             filename = utf8_deserializer.loads(infile)
-            sys.path.append(os.path.join(spark_files_dir, filename))
+            add_path(os.path.join(spark_files_dir, filename))
 
         # fetch names and values of broadcast variables
         num_broadcast_variables = read_int(infile)
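
Why the patch uses insert(1, ...) rather than append(): Python resolves imports by scanning sys.path in order, so an entry appended at the end loses to any same-named package already installed in site-packages, while an entry inserted at index 1 (just after sys.path[0], which is reserved for the running script's directory) is searched before system packages. A minimal sketch of the effect, runnable outside Spark; the directory names below are hypothetical:

    import sys

    def add_path(path):
        # Same idea as the patch's helper: skip duplicates, then insert at
        # index 1 so the entry is searched after sys.path[0] (the script's
        # directory) but before site-packages.
        if path not in sys.path:
            sys.path.insert(1, path)

    add_path("/tmp/spark-files/mylib.zip")  # e.g. a file shipped via addPyFile()
    add_path("/tmp/spark-files/mylib.zip")  # no-op: already on sys.path
    print(sys.path[1])                      # -> /tmp/spark-files/mylib.zip

With append(), a stale copy of mylib already installed on the workers would still win; with insert(1, ...), the version shipped with the job is imported instead.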