Skip to content

Commit

Permalink
[SPARK-49928][PYTHON][TESTS] Refactor plot-related unit tests
Browse files (browse the repository at this point in the history)
### What changes were proposed in this pull request?
Refactor plot-related unit tests.

### Why are the changes needed?
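Each plot kind exposes a different set of key attributes on the resulting figure (for example `mode` for line and area plots, `orientation` for bar plots, and `labels`/`values` for pie plots). The refactor lets each test pass the exact attributes it expects as keyword arguments, which makes the comparison easier.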

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Test changes.

### Was this patch authored or co-authored using generative AI tooling?
No.

Closes #48415 from xinrong-meng/plot_test.

Authored-by: Xinrong Meng <xinrong@apache.org>
Signed-off-by: Xinrong Meng <xinrong@apache.org>
  • Loading branch information
xinrong-meng committed Oct 14, 2024
1 parent 1abfd49 commit 1aae160
Showing 1 changed file with 192 additions and 50 deletions.
242 changes: 192 additions & 50 deletions python/pyspark/sql/tests/plot/test_frame_plot_plotly.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,79 +48,174 @@ def sdf3(self):
columns = ["sales", "signups", "visits", "date"]
return self.spark.createDataFrame(data, columns)

def _check_fig_data(self, kind, fig_data, expected_x, expected_y, expected_name=""):
if kind == "line":
self.assertEqual(fig_data["mode"], "lines")
self.assertEqual(fig_data["type"], "scatter")
elif kind == "bar":
self.assertEqual(fig_data["type"], "bar")
elif kind == "barh":
self.assertEqual(fig_data["type"], "bar")
self.assertEqual(fig_data["orientation"], "h")
elif kind == "scatter":
self.assertEqual(fig_data["type"], "scatter")
self.assertEqual(fig_data["orientation"], "v")
self.assertEqual(fig_data["mode"], "markers")
elif kind == "area":
self.assertEqual(fig_data["type"], "scatter")
self.assertEqual(fig_data["orientation"], "v")
self.assertEqual(fig_data["mode"], "lines")
elif kind == "pie":
self.assertEqual(fig_data["type"], "pie")
self.assertEqual(list(fig_data["labels"]), expected_x)
self.assertEqual(list(fig_data["values"]), expected_y)
return

self.assertEqual(fig_data["xaxis"], "x")
self.assertEqual(list(fig_data["x"]), expected_x)
self.assertEqual(fig_data["yaxis"], "y")
self.assertEqual(list(fig_data["y"]), expected_y)
self.assertEqual(fig_data["name"], expected_name)
def _check_fig_data(self, fig_data, **kwargs):
    """Assert that a figure trace carries the expected attribute values.

    Parameters
    ----------
    fig_data
        One trace of the figure under test (e.g. ``fig["data"][0]``),
        indexable by attribute name.
    **kwargs
        Attribute name -> expected value. Only the attributes passed here
        are checked; anything else on the trace is ignored.
    """
    # Attributes that hold a sequence of data points rather than a single setting.
    sequence_keys = ("x", "y", "labels", "values")
    for key, expected_value in kwargs.items():
        actual_value = fig_data[key]
        if key in sequence_keys:
            # Unwrap elements exposing ``.item()`` (presumably NumPy scalars -- confirm
            # against the plotting backend) so they compare equal to plain Python values.
            actual_value = [v.item() if hasattr(v, "item") else v for v in actual_value]
        # Name the attribute in the failure message so a mismatch is easy to locate.
        self.assertEqual(
            actual_value, expected_value, msg=f"unexpected value for attribute {key!r}"
        )

def test_line_plot(self):
# single column as vertical axis
fig = self.sdf.plot(kind="line", x="category", y="int_val")
self._check_fig_data("line", fig["data"][0], ["A", "B", "C"], [10, 30, 20])
expected_fig_data = {
"mode": "lines",
"name": "",
"orientation": "v",
"x": ["A", "B", "C"],
"xaxis": "x",
"y": [10, 30, 20],
"yaxis": "y",
"type": "scatter",
}
self._check_fig_data(fig["data"][0], **expected_fig_data)

# multiple columns as vertical axis
fig = self.sdf.plot.line(x="category", y=["int_val", "float_val"])
self._check_fig_data("line", fig["data"][0], ["A", "B", "C"], [10, 30, 20], "int_val")
self._check_fig_data("line", fig["data"][1], ["A", "B", "C"], [1.5, 2.5, 3.5], "float_val")
expected_fig_data = {
"mode": "lines",
"name": "int_val",
"orientation": "v",
"x": ["A", "B", "C"],
"xaxis": "x",
"y": [10, 30, 20],
"yaxis": "y",
"type": "scatter",
}
self._check_fig_data(fig["data"][0], **expected_fig_data)
expected_fig_data = {
"mode": "lines",
"name": "float_val",
"orientation": "v",
"x": ["A", "B", "C"],
"xaxis": "x",
"y": [1.5, 2.5, 3.5],
"yaxis": "y",
"type": "scatter",
}
self._check_fig_data(fig["data"][1], **expected_fig_data)

def test_bar_plot(self):
# single column as vertical axis
fig = self.sdf.plot(kind="bar", x="category", y="int_val")
self._check_fig_data("bar", fig["data"][0], ["A", "B", "C"], [10, 30, 20])
expected_fig_data = {
"name": "",
"orientation": "v",
"x": ["A", "B", "C"],
"xaxis": "x",
"y": [10, 30, 20],
"yaxis": "y",
"type": "bar",
}
self._check_fig_data(fig["data"][0], **expected_fig_data)

# multiple columns as vertical axis
fig = self.sdf.plot.bar(x="category", y=["int_val", "float_val"])
self._check_fig_data("bar", fig["data"][0], ["A", "B", "C"], [10, 30, 20], "int_val")
self._check_fig_data("bar", fig["data"][1], ["A", "B", "C"], [1.5, 2.5, 3.5], "float_val")
expected_fig_data = {
"name": "int_val",
"orientation": "v",
"x": ["A", "B", "C"],
"xaxis": "x",
"y": [10, 30, 20],
"yaxis": "y",
"type": "bar",
}
self._check_fig_data(fig["data"][0], **expected_fig_data)
expected_fig_data = {
"name": "float_val",
"orientation": "v",
"x": ["A", "B", "C"],
"xaxis": "x",
"y": [1.5, 2.5, 3.5],
"yaxis": "y",
"type": "bar",
}
self._check_fig_data(fig["data"][1], **expected_fig_data)

def test_barh_plot(self):
# single column as vertical axis
fig = self.sdf.plot(kind="barh", x="category", y="int_val")
self._check_fig_data("barh", fig["data"][0], ["A", "B", "C"], [10, 30, 20])
expected_fig_data = {
"name": "",
"orientation": "h",
"x": ["A", "B", "C"],
"xaxis": "x",
"y": [10, 30, 20],
"yaxis": "y",
"type": "bar",
}
self._check_fig_data(fig["data"][0], **expected_fig_data)

# multiple columns as vertical axis
fig = self.sdf.plot.barh(x="category", y=["int_val", "float_val"])
self._check_fig_data("barh", fig["data"][0], ["A", "B", "C"], [10, 30, 20], "int_val")
self._check_fig_data("barh", fig["data"][1], ["A", "B", "C"], [1.5, 2.5, 3.5], "float_val")
expected_fig_data = {
"name": "int_val",
"orientation": "h",
"x": ["A", "B", "C"],
"xaxis": "x",
"y": [10, 30, 20],
"yaxis": "y",
"type": "bar",
}
self._check_fig_data(fig["data"][0], **expected_fig_data)
expected_fig_data = {
"name": "float_val",
"orientation": "h",
"x": ["A", "B", "C"],
"xaxis": "x",
"y": [1.5, 2.5, 3.5],
"yaxis": "y",
"type": "bar",
}
self._check_fig_data(fig["data"][1], **expected_fig_data)

# multiple columns as horizontal axis
fig = self.sdf.plot.barh(x=["int_val", "float_val"], y="category")
self._check_fig_data("barh", fig["data"][0], [10, 30, 20], ["A", "B", "C"], "int_val")
self._check_fig_data("barh", fig["data"][1], [1.5, 2.5, 3.5], ["A", "B", "C"], "float_val")
expected_fig_data = {
"name": "int_val",
"orientation": "h",
"y": ["A", "B", "C"],
"xaxis": "x",
"x": [10, 30, 20],
"yaxis": "y",
"type": "bar",
}
self._check_fig_data(fig["data"][0], **expected_fig_data)
expected_fig_data = {
"name": "float_val",
"orientation": "h",
"y": ["A", "B", "C"],
"xaxis": "x",
"x": [1.5, 2.5, 3.5],
"yaxis": "y",
"type": "bar",
}
self._check_fig_data(fig["data"][1], **expected_fig_data)

def test_scatter_plot(self):
fig = self.sdf2.plot(kind="scatter", x="length", y="width")
self._check_fig_data(
"scatter", fig["data"][0], [5.1, 4.9, 7.0, 6.4, 5.9], [3.5, 3.0, 3.2, 3.2, 3.0]
)
expected_fig_data = {
"name": "",
"orientation": "v",
"x": [5.1, 4.9, 7.0, 6.4, 5.9],
"xaxis": "x",
"y": [3.5, 3.0, 3.2, 3.2, 3.0],
"yaxis": "y",
"type": "scatter",
}
self._check_fig_data(fig["data"][0], **expected_fig_data)
expected_fig_data = {
"name": "",
"orientation": "v",
"y": [5.1, 4.9, 7.0, 6.4, 5.9],
"xaxis": "x",
"x": [3.5, 3.0, 3.2, 3.2, 3.0],
"yaxis": "y",
"type": "scatter",
}
fig = self.sdf2.plot.scatter(x="width", y="length")
self._check_fig_data(
"scatter", fig["data"][0], [3.5, 3.0, 3.2, 3.2, 3.0], [5.1, 4.9, 7.0, 6.4, 5.9]
)
self._check_fig_data(fig["data"][0], **expected_fig_data)

def test_area_plot(self):
# single column as vertical axis
Expand All @@ -131,13 +226,53 @@ def test_area_plot(self):
datetime(2018, 3, 31, 0, 0),
datetime(2018, 4, 30, 0, 0),
]
self._check_fig_data("area", fig["data"][0], expected_x, [3, 2, 3, 9])
expected_fig_data = {
"name": "",
"orientation": "v",
"x": expected_x,
"xaxis": "x",
"y": [3, 2, 3, 9],
"yaxis": "y",
"mode": "lines",
"type": "scatter",
}
self._check_fig_data(fig["data"][0], **expected_fig_data)

# multiple columns as vertical axis
fig = self.sdf3.plot.area(x="date", y=["sales", "signups", "visits"])
self._check_fig_data("area", fig["data"][0], expected_x, [3, 2, 3, 9], "sales")
self._check_fig_data("area", fig["data"][1], expected_x, [5, 5, 6, 12], "signups")
self._check_fig_data("area", fig["data"][2], expected_x, [20, 42, 28, 62], "visits")
expected_fig_data = {
"name": "sales",
"orientation": "v",
"x": expected_x,
"xaxis": "x",
"y": [3, 2, 3, 9],
"yaxis": "y",
"mode": "lines",
"type": "scatter",
}
self._check_fig_data(fig["data"][0], **expected_fig_data)
expected_fig_data = {
"name": "signups",
"orientation": "v",
"x": expected_x,
"xaxis": "x",
"y": [5, 5, 6, 12],
"yaxis": "y",
"mode": "lines",
"type": "scatter",
}
self._check_fig_data(fig["data"][1], **expected_fig_data)
expected_fig_data = {
"name": "visits",
"orientation": "v",
"x": expected_x,
"xaxis": "x",
"y": [20, 42, 28, 62],
"yaxis": "y",
"mode": "lines",
"type": "scatter",
}
self._check_fig_data(fig["data"][2], **expected_fig_data)

def test_pie_plot(self):
fig = self.sdf3.plot(kind="pie", x="date", y="sales")
Expand All @@ -147,11 +282,18 @@ def test_pie_plot(self):
datetime(2018, 3, 31, 0, 0),
datetime(2018, 4, 30, 0, 0),
]
self._check_fig_data("pie", fig["data"][0], expected_x, [3, 2, 3, 9])
expected_fig_data = {
"name": "",
"labels": expected_x,
"values": [3, 2, 3, 9],
"type": "pie",
}
self._check_fig_data(fig["data"][0], **expected_fig_data)

# y is not a numerical column
with self.assertRaises(PySparkTypeError) as pe:
self.sdf.plot.pie(x="int_val", y="category")

self.check_error(
exception=pe.exception,
errorClass="PLOT_NOT_NUMERIC_COLUMN",
Expand Down

0 comments on commit 1aae160

Please sign in to comment.