
Avoid overwriting deserialized accumulator.
viirya committed Jan 3, 2019
1 parent c3d759f commit 30f4724
Showing 2 changed files with 33 additions and 4 deletions.
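
Some context on why the guard matters (editorial summary, hedged): on a Python worker, accumulators live in the module-level _accumulatorRegistry dict keyed by accumulator id, and the worker reports the registry entries' values back to the driver at the end of a task. The old code recreated the registry entry on every deserialization, so when two UDFs in the same task captured the same accumulator, the second UDF's deserialization replaced the entry the first UDF was updating. A minimal sketch of the failure with plain dicts (illustrative names, not PySpark's classes):

    _registry = {}

    def deserialize_old(aid, zero_value):
        # Pre-fix behavior: unconditionally overwrite the registry entry.
        _registry[aid] = {"value": zero_value}
        return _registry[aid]

    def deserialize_fixed(aid, zero_value):
        # Post-fix behavior: keep the entry if this id was already seen.
        if aid in _registry:
            return _registry[aid]
        _registry[aid] = {"value": zero_value}
        return _registry[aid]

    # Two UDFs in one task, both closing over accumulator id 0.
    a = deserialize_old(0, 0)
    a["value"] += 1                  # first UDF adds 1
    b = deserialize_old(0, 0)        # second UDF deserializes: entry reset
    b["value"] += 100
    print(_registry[0]["value"])     # 100 -- the +1 was lost

    _registry.clear()
    a = deserialize_fixed(0, 0)
    a["value"] += 1
    b = deserialize_fixed(0, 0)      # the same entry comes back
    b["value"] += 100
    print(_registry[0]["value"])     # 101 -- both updates survive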
12 changes: 8 additions & 4 deletions python/pyspark/accumulators.py
@@ -110,10 +110,14 @@

 def _deserialize_accumulator(aid, zero_value, accum_param):
     from pyspark.accumulators import _accumulatorRegistry
-    accum = Accumulator(aid, zero_value, accum_param)
-    accum._deserialized = True
-    _accumulatorRegistry[aid] = accum
-    return accum
+    # If this certain accumulator was deserialized, don't overwrite it.
+    if aid in _accumulatorRegistry:
+        return _accumulatorRegistry[aid]
+    else:
+        accum = Accumulator(aid, zero_value, accum_param)
+        accum._deserialized = True
+        _accumulatorRegistry[aid] = accum
+        return accum
 
 
 class Accumulator(object):
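How _deserialize_accumulator comes to run more than once per worker (background, with some assumptions about pyspark/accumulators.py): an Accumulator pickles itself through __reduce__ as a deferred call to _deserialize_accumulator, so each UDF closure that captures the accumulator re-enters this function when the worker unpickles it. A self-contained mock of that mechanism (toy classes, not PySpark's), showing the post-fix behavior:

    import pickle

    _registry = {}

    class Acc(object):
        """Toy accumulator; the real class lives in pyspark/accumulators.py."""
        def __init__(self, aid, value):
            self.aid = aid
            self.value = value

        def __reduce__(self):
            # Serialize as a deferred call to _deserialize(aid, zero_value),
            # loosely imitating Accumulator.__reduce__ in PySpark.
            return (_deserialize, (self.aid, 0))

    def _deserialize(aid, zero_value):
        # The fixed logic: reuse an already-registered accumulator.
        if aid in _registry:
            return _registry[aid]
        acc = Acc(aid, zero_value)
        _registry[aid] = acc
        return acc

    acc = Acc(1, 0)
    first_closure = pickle.dumps(acc)   # pickled into the first UDF
    second_closure = pickle.dumps(acc)  # pickled again into the second UDF

    a = pickle.loads(first_closure)
    a.value += 1
    b = pickle.loads(second_closure)    # guard returns the same object
    b.value += 100
    assert a is b
    assert a.value == 101

With the pre-fix logic, the second pickle.loads would have rebuilt and re-registered a zero-valued object, which is exactly the overwrite this commit removes.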
25 changes: 25 additions & 0 deletions python/pyspark/sql/tests.py
@@ -2932,6 +2932,31 @@ def test_create_dateframe_from_pandas_with_dst(self):
                 os.environ['TZ'] = orig_env_tz
             time.tzset()
 
+    # SPARK-25591
+    def test_same_accumulator_in_udfs(self):
+        from pyspark.sql.functions import udf
+
+        data_schema = StructType([StructField("a", IntegerType(), True),
+                                  StructField("b", IntegerType(), True)])
+        data = self.spark.createDataFrame([[1, 2]], schema=data_schema)
+
+        test_accum = self.sc.accumulator(0)
+
+        def first_udf(x):
+            test_accum.add(1)
+            return x
+
+        def second_udf(x):
+            test_accum.add(100)
+            return x
+
+        func_udf = udf(first_udf, IntegerType())
+        func_udf2 = udf(second_udf, IntegerType())
+        data = data.withColumn("out1", func_udf(data["a"]))
+        data = data.withColumn("out2", func_udf2(data["b"]))
+        data.collect()
+        self.assertEqual(test_accum.value, 101)
+
 
 class HiveSparkSubmitTests(SparkSubmitTests):

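A note on the test (editorial, not part of the commit): the single row [1, 2] passes through each UDF exactly once, so the expected total is 1 + 100 = 101; the two distinct increments make the failure unambiguous. Before this patch, the second UDF's deserialization would evict the first UDF's accumulator from the registry, so one contribution would never reach the driver and the assertion would fail with a short count.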
