[SPARK-25591][PySpark][SQL][BRANCH-2.3] Avoid overwriting deserialized accumulator #23432

Closed
python/pyspark/accumulators.py (8 additions, 4 deletions)
@@ -110,10 +110,14 @@

 def _deserialize_accumulator(aid, zero_value, accum_param):
     from pyspark.accumulators import _accumulatorRegistry
-    accum = Accumulator(aid, zero_value, accum_param)
-    accum._deserialized = True
-    _accumulatorRegistry[aid] = accum
-    return accum
+    # If this accumulator was already deserialized, don't overwrite it.
+    if aid in _accumulatorRegistry:
+        return _accumulatorRegistry[aid]
+    else:
+        accum = Accumulator(aid, zero_value, accum_param)
+        accum._deserialized = True
+        _accumulatorRegistry[aid] = accum
+        return accum


 class Accumulator(object):
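Why the guard matters: each Python UDF ships its accumulator to the worker by pickling, and unpickling calls _deserialize_accumulator. Before this patch, a second UDF referencing the same accumulator id replaced the registry entry, so updates recorded through the first, now orphaned, instance were never reported back to the driver. A minimal standalone sketch of the failure mode and the fix; FakeAccumulator, registry and the deserialize_* helpers below are illustrative stand-ins, not PySpark's real classes:

registry = {}

class FakeAccumulator:
    # Toy accumulator; the worker reports only registry entries to the driver.
    def __init__(self, aid):
        self.aid = aid
        self.value = 0

def deserialize_overwriting(aid):
    # Pre-patch behavior: always create and register a fresh instance.
    acc = FakeAccumulator(aid)
    registry[aid] = acc
    return acc

def deserialize_guarded(aid):
    # Post-patch behavior: reuse the instance already in the registry.
    if aid in registry:
        return registry[aid]
    acc = FakeAccumulator(aid)
    registry[aid] = acc
    return acc

# Two UDFs on the same worker each unpickle accumulator id 1.
first = deserialize_overwriting(1)
second = deserialize_overwriting(1)
first.value += 1     # update made through the first UDF's reference
second.value += 100  # update made through the second UDF's reference
assert registry[1].value == 100  # the +1 is lost with the orphaned instance

registry.clear()
first = deserialize_guarded(1)
second = deserialize_guarded(1)
first.value += 1
second.value += 100
assert first is second
assert registry[1].value == 101  # both updates survive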
python/pyspark/sql/tests.py (25 additions, 0 deletions)
@@ -2932,6 +2932,31 @@ def test_create_dateframe_from_pandas_with_dst(self):
             os.environ['TZ'] = orig_env_tz
             time.tzset()

+    # SPARK-25591
+    def test_same_accumulator_in_udfs(self):
+        from pyspark.sql.functions import udf
+
+        data_schema = StructType([StructField("a", IntegerType(), True),
+                                  StructField("b", IntegerType(), True)])
+        data = self.spark.createDataFrame([[1, 2]], schema=data_schema)
+
+        test_accum = self.sc.accumulator(0)
+
+        def first_udf(x):
+            test_accum.add(1)
+            return x
+
+        def second_udf(x):
+            test_accum.add(100)
+            return x
+
+        func_udf = udf(first_udf, IntegerType())
+        func_udf2 = udf(second_udf, IntegerType())
+        data = data.withColumn("out1", func_udf(data["a"]))
+        data = data.withColumn("out2", func_udf2(data["b"]))
+        data.collect()
+        self.assertEqual(test_accum.value, 101)
+

 class HiveSparkSubmitTests(SparkSubmitTests):
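For context on how the deserializer is reached at all: Accumulator defines __reduce__ so that pickling a UDF closure ships only the accumulator's id, current value, and accumulator param, and unpickling on the worker goes through _deserialize_accumulator. A hedged sketch of that round trip under the patched, idempotent logic; Acc, _deserialize and _registry are simplified stand-ins rather than PySpark's actual definitions:

import pickle

_registry = {}

def _deserialize(aid, zero_value):
    if aid in _registry:  # the guard added by this patch
        return _registry[aid]
    acc = Acc(aid, zero_value)
    _registry[aid] = acc
    return acc

class Acc:
    def __init__(self, aid, zero_value):
        self.aid = aid
        self.value = zero_value

    def __reduce__(self):
        # Serialize only (aid, value); reconstruct through _deserialize.
        return _deserialize, (self.aid, self.value)

driver_acc = Acc(7, 0)
blob1 = pickle.dumps(driver_acc)  # e.g. captured by the first UDF's closure
blob2 = pickle.dumps(driver_acc)  # e.g. captured by the second UDF's closure
assert pickle.loads(blob1) is pickle.loads(blob2)  # one shared worker instance

With the guard in place, first_udf (adding 1) and second_udf (adding 100) update the same registered instance, which is why the test above expects 101 for the single input row.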