diff --git a/luigi/contrib/spark.py b/luigi/contrib/spark.py index 668b7ba19a..84068043f1 100644 --- a/luigi/contrib/spark.py +++ b/luigi/contrib/spark.py @@ -292,6 +292,10 @@ def files(self): if self.deploy_mode == "cluster": return [self.run_pickle] + @property + def pickle_protocol(self): + return configuration.get_config().getint('spark', 'pickle-protocol', pickle.DEFAULT_PROTOCOL) + def setup(self, conf): """ Called by the pyspark_runner with a SparkConf instance that will be used to instantiate the SparkContext @@ -335,12 +339,12 @@ def run(self): def _dump(self, fd): with self.no_unpicklable_properties(): if self.__module__ == '__main__': - d = pickle.dumps(self) + d = pickle.dumps(self, protocol=self.pickle_protocol) module_name = os.path.basename(sys.argv[0]).rsplit('.', 1)[0] d = d.replace(b'c__main__', b'c' + module_name.encode('ascii')) fd.write(d) else: - pickle.dump(self, fd) + pickle.dump(self, fd, protocol=self.pickle_protocol) def _setup_packages(self, sc): """