Skip to content

Commit

Permalink
put addPyFile in front of sys.path
Browse files Browse the repository at this point in the history
  • Loading branch information
davies committed Sep 22, 2014
1 parent 56dae30 commit c16c392
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
11 changes: 5 additions & 6 deletions python/pyspark/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,

SparkFiles._sc = self
root_dir = SparkFiles.getRootDirectory()
sys.path.append(root_dir)
sys.path.insert(1, root_dir)

# Deploy any code dependencies specified in the constructor
self._python_includes = list()
Expand All @@ -183,10 +183,9 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
for path in self._conf.get("spark.submit.pyFiles", "").split(","):
if path != "":
(dirname, filename) = os.path.split(path)
self._python_includes.append(filename)
sys.path.append(path)
if dirname not in sys.path:
sys.path.append(dirname)
if filename.lower().endswith("zip") or filename.lower().endswith("egg"):
self._python_includes.append(filename)
sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename))

# Create a temporary directory inside spark.local.dir:
local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir(self._jsc.sc().conf())
Expand Down Expand Up @@ -667,7 +666,7 @@ def addPyFile(self, path):
if filename.endswith('.zip') or filename.endswith('.ZIP') or filename.endswith('.egg'):
self._python_includes.append(filename)
# for tests in local mode
sys.path.append(os.path.join(SparkFiles.getRootDirectory(), filename))
sys.path.insert(1, os.path.join(SparkFiles.getRootDirectory(), filename))

def setCheckpointDir(self, dirName):
"""
Expand Down
11 changes: 9 additions & 2 deletions python/pyspark/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@ def report_times(outfile, boot, init, finish):
write_long(1000 * finish, outfile)


def add_path(path):
# worker can be used, so donot add path multiple times
if path not in sys.path:
# overwrite system packages
sys.path.insert(1, path)


def main(infile, outfile):
try:
boot_time = time.time()
Expand All @@ -61,11 +68,11 @@ def main(infile, outfile):
SparkFiles._is_running_on_worker = True

# fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
sys.path.append(spark_files_dir) # *.py files that were added will be copied here
add_path(spark_files_dir) # *.py files that were added will be copied here
num_python_includes = read_int(infile)
for _ in range(num_python_includes):
filename = utf8_deserializer.loads(infile)
sys.path.append(os.path.join(spark_files_dir, filename))
add_path(os.path.join(spark_files_dir, filename))

# fetch names and values of broadcast variables
num_broadcast_variables = read_int(infile)
Expand Down

0 comments on commit c16c392

Please sign in to comment.