Skip to content

Commit

Permalink
Improved plot rendering (#8580)
Browse files Browse the repository at this point in the history
* changes

* add changeset

* add changeset

* add changeset

* changes

* changes

* restore altair

* changes

* changes

* changes

* changes

* changes

* changes

* Update twenty-jokes-argue.md

* changes

* chanegs

* changes

* changes

* changes

* changes

---------

Co-authored-by: Ali Abid <aliabid94@gmail.com>
Co-authored-by: gradio-pr-bot <gradio-pr-bot@users.noreply.github.com>
  • Loading branch information
3 people authored Jun 20, 2024
1 parent 1e61644 commit 797621b
Show file tree
Hide file tree
Showing 21 changed files with 562 additions and 352 deletions.
7 changes: 7 additions & 0 deletions .changeset/twenty-jokes-argue.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
---
"@gradio/plot": patch
"gradio": patch
---

feat:Improved plot rendering to thematically match
highlight:Expect visual changes in gr.Plot, gr.BarPlot, gr.LinePlot, gr.ScatterPlot, including changes to color and width sizing.
18 changes: 9 additions & 9 deletions demo/native_plots/bar_plot_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,14 @@ def bar_plot_fn(display):


with gr.Blocks() as bar_plot:
with gr.Row():
with gr.Column():
display = gr.Dropdown(
choices=["simple", "stacked", "grouped", "simple-horizontal", "stacked-horizontal", "grouped-horizontal"],
value="simple",
label="Type of Bar Plot"
)
with gr.Column():
plot = gr.BarPlot(show_label=False, show_actions_button=True)
display = gr.Dropdown(
choices=["simple", "stacked", "grouped", "simple-horizontal", "stacked-horizontal", "grouped-horizontal"],
value="simple",
label="Type of Bar Plot"
)
plot = gr.BarPlot(show_label=False)
display.change(bar_plot_fn, inputs=display, outputs=plot)
bar_plot.load(fn=bar_plot_fn, inputs=display, outputs=plot)

if __name__ == "__main__":
bar_plot.launch()
21 changes: 5 additions & 16 deletions demo/native_plots/line_plot_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ def line_plot_fn(dataset):
overlay_point=False,
title="Stock Prices",
stroke_dash_legend_title=None,
height=300,
width=500
)
elif dataset == "climate":
return gr.LinePlot(
Expand All @@ -40,8 +38,6 @@ def line_plot_fn(dataset):
overlay_point=False,
title="Climate",
stroke_dash_legend_title=None,
height=300,
width=500
)
elif dataset == "seattle_weather":
return gr.LinePlot(
Expand All @@ -56,8 +52,6 @@ def line_plot_fn(dataset):
overlay_point=True,
title="Seattle Weather",
stroke_dash_legend_title=None,
height=300,
width=500
)
elif dataset == "gapminder":
return gr.LinePlot(
Expand All @@ -72,20 +66,15 @@ def line_plot_fn(dataset):
overlay_point=False,
title="Life expectancy for countries",
stroke_dash_legend_title="Country Cluster",
height=300,
width=500
)


with gr.Blocks() as line_plot:
with gr.Row():
with gr.Column():
dataset = gr.Dropdown(
choices=["stocks", "climate", "seattle_weather", "gapminder"],
value="stocks",
)
with gr.Column():
plot = gr.LinePlot()
dataset = gr.Dropdown(
choices=["stocks", "climate", "seattle_weather", "gapminder"],
value="stocks",
)
plot = gr.LinePlot()
dataset.change(line_plot_fn, inputs=dataset, outputs=plot)
line_plot.load(fn=line_plot_fn, inputs=dataset, outputs=plot)

Expand Down
15 changes: 7 additions & 8 deletions demo/native_plots/scatter_plot_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@ def scatter_plot_fn(dataset):
value=iris,
x="petalWidth",
y="petalLength",
color="species",
color=None,
title="Iris Dataset",
color_legend_title="Species",
x_title="Petal Width",
y_title="Petal Length",
tooltip=["petalWidth", "petalLength", "species"],
caption="",
height=600,
width=600,
)
else:
return gr.ScatterPlot(
Expand All @@ -29,17 +30,15 @@ def scatter_plot_fn(dataset):
tooltip="Name",
title="Car Data",
y_title="Miles per Gallon",
color_legend_title="Origin of Car",
caption="MPG vs Horsepower of various cars",
height=None,
width=None,
)


with gr.Blocks() as scatter_plot:
with gr.Row():
with gr.Column():
dataset = gr.Dropdown(choices=["cars", "iris"], value="cars")
with gr.Column():
plot = gr.ScatterPlot(show_label=False)
dataset = gr.Dropdown(choices=["cars", "iris"], value="cars")
plot = gr.ScatterPlot(show_label=False)
dataset.change(scatter_plot_fn, inputs=dataset, outputs=plot)
scatter_plot.load(fn=scatter_plot_fn, inputs=dataset, outputs=plot)

Expand Down
34 changes: 22 additions & 12 deletions gradio/components/bar_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any, Callable, Literal

from gradio_client.documentation import document
Expand Down Expand Up @@ -52,8 +53,8 @@ def __init__(
"none",
]
| None = None,
height: int | str | None = None,
width: int | str | None = None,
height: int | None = None,
width: int | None = None,
y_lim: list[int] | None = None,
caption: str | None = None,
interactive: bool | None = True,
Expand Down Expand Up @@ -88,8 +89,8 @@ def __init__(
color_legend_title: The title given to the color legend. By default, uses the value of color parameter.
group_title: The label displayed on top of the subplot columns (or rows if vertical=True). Use an empty string to omit.
color_legend_position: The position of the color legend. If the string value 'none' is passed, this legend is omitted. For other valid position values see: https://vega.github.io/vega/docs/legends/#orientation.
height: The height of the plot, specified in pixels if a number is passed, or in CSS units if a string is passed.
width: The width of the plot, specified in pixels if a number is passed, or in CSS units if a string is passed.
height: The height of the plot in pixels.
width: The width of the plot in pixels. If None, expands to fit.
y_lim: A tuple of list containing the limits for the y-axis, specified as [y_min, y_max].
caption: The (optional) caption to display below the plot.
interactive: Whether users should be able to interact with the plot by panning or zooming with their mouse or trackpad.
Expand Down Expand Up @@ -122,10 +123,22 @@ def __init__(
self.y_lim = y_lim
self.caption = caption
self.interactive_chart = interactive
if isinstance(width, str):
width = None
warnings.warn(
"Width should be an integer, not a string. Setting width to None."
)
if isinstance(height, str):
warnings.warn(
"Height should be an integer, not a string. Setting height to None."
)
height = None
self.width = width
self.height = height
self.sort = sort
self.show_actions_button = show_actions_button
if label is None and show_label is None:
show_label = False
super().__init__(
value=value,
label=label,
Expand Down Expand Up @@ -172,8 +185,8 @@ def create_plot(
"none",
]
| None = None,
height: int | str | None = None,
width: int | str | None = None,
height: int | None = None,
width: int | None = None,
y_lim: list[int] | None = None,
interactive: bool | None = True,
sort: Literal["x", "y", "-x", "-y"] | None = None,
Expand All @@ -182,11 +195,7 @@ def create_plot(
import altair as alt

interactive = True if interactive is None else interactive
orientation = (
{"field": group, "title": group_title if group_title is not None else group}
if group
else {}
)
orientation = {"field": group, "title": group_title} if group else {}

x_title = x_title or x
y_title = y_title or y
Expand Down Expand Up @@ -234,14 +243,15 @@ def create_plot(
properties["width"] = width

if color:
color_legend_position = color_legend_position or "bottom"
domain = value[color].unique().tolist()
range_ = list(range(len(domain)))
encodings["color"] = {
"field": color,
"type": "nominal",
"scale": {"domain": domain, "range": range_},
"legend": AltairPlot.create_legend(
position=color_legend_position, title=color_legend_title or color
position=color_legend_position, title=color_legend_title
),
}

Expand Down
24 changes: 19 additions & 5 deletions gradio/components/line_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any, Callable, Literal

from gradio_client.documentation import document
Expand Down Expand Up @@ -64,8 +65,8 @@ def __init__(
"none",
]
| None = None,
height: int | str | None = None,
width: int | str | None = None,
height: int | None = None,
width: int | None = None,
x_lim: list[int] | None = None,
y_lim: list[int] | None = None,
caption: str | None = None,
Expand Down Expand Up @@ -101,8 +102,8 @@ def __init__(
stroke_dash_legend_title: The title given to the stroke_dash legend. By default, uses the value of the stroke_dash parameter.
color_legend_position: The position of the color legend. If the string value 'none' is passed, this legend is omitted. For other valid position values see: https://vega.github.io/vega/docs/legends/#orientation.
stroke_dash_legend_position: The position of the stoke_dash legend. If the string value 'none' is passed, this legend is omitted. For other valid position values see: https://vega.github.io/vega/docs/legends/#orientation.
height: The height of the plot, specified in pixels if a number is passed, or in CSS units if a string is passed.
width: The width of the plot, specified in pixels if a number is passed, or in CSS units if a string is passed.
height: The height of the plot in pixels.
width: The width of the plot in pixels. If None, expands to fit.
x_lim: A tuple or list containing the limits for the x-axis, specified as [x_min, x_max].
y_lim: A tuple of list containing the limits for the y-axis, specified as [y_min, y_max].
caption: The (optional) caption to display below the plot.
Expand Down Expand Up @@ -136,9 +137,21 @@ def __init__(
self.y_lim = y_lim
self.caption = caption
self.interactive_chart = interactive
if isinstance(width, str):
width = None
warnings.warn(
"Width should be an integer, not a string. Setting width to None."
)
if isinstance(height, str):
warnings.warn(
"Height should be an integer, not a string. Setting height to None."
)
height = None
self.width = width
self.height = height
self.show_actions_button = show_actions_button
if label is None and show_label is None:
show_label = False
super().__init__(
value=value,
label=label,
Expand Down Expand Up @@ -234,14 +247,15 @@ def create_plot(
properties["width"] = width

if color:
color_legend_position = color_legend_position or "bottom"
domain = value[color].unique().tolist()
range_ = list(range(len(domain)))
encodings["color"] = {
"field": color,
"type": "nominal",
"scale": {"domain": domain, "range": range_},
"legend": AltairPlot.create_legend(
position=color_legend_position, title=color_legend_title or color
position=color_legend_position, title=color_legend_title
),
}

Expand Down
28 changes: 21 additions & 7 deletions gradio/components/scatter_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any, Callable, Literal

from gradio_client.documentation import document
Expand Down Expand Up @@ -77,8 +78,8 @@ def __init__(
"none",
]
| None = None,
height: int | str | None = None,
width: int | str | None = None,
height: int | None = None,
width: int | None = None,
x_lim: list[int | float] | None = None,
y_lim: list[int | float] | None = None,
caption: str | None = None,
Expand Down Expand Up @@ -116,8 +117,8 @@ def __init__(
color_legend_position: The position of the color legend. If the string value 'none' is passed, this legend is omitted. For other valid position values see: https://vega.github.io/vega/docs/legends/#orientation.
size_legend_position: The position of the size legend. If the string value 'none' is passed, this legend is omitted. For other valid position values see: https://vega.github.io/vega/docs/legends/#orientation.
shape_legend_position: The position of the shape legend. If the string value 'none' is passed, this legend is omitted. For other valid position values see: https://vega.github.io/vega/docs/legends/#orientation.
height: The height of the plot, specified in pixels if a number is passed, or in CSS units if a string is passed.
width: The width of the plot, specified in pixels if a number is passed, or in CSS units if a string is passed.
height: The height of the plot in pixels.
width: The width of the plot in pixels. If None, expands to fit.
x_lim: A tuple or list containing the limits for the x-axis, specified as [x_min, x_max].
y_lim: A tuple of list containing the limits for the y-axis, specified as [y_min, y_max].
caption: The (optional) caption to display below the plot.
Expand Down Expand Up @@ -151,11 +152,23 @@ def __init__(
self.shape_legend_position = shape_legend_position
self.caption = caption
self.interactive_chart = interactive
if isinstance(width, str):
width = None
warnings.warn(
"Width should be an integer, not a string. Setting width to None."
)
if isinstance(height, str):
warnings.warn(
"Height should be an integer, not a string. Setting height to None."
)
height = None
self.width = width
self.height = height
self.x_lim = x_lim
self.y_lim = y_lim
self.show_actions_button = show_actions_button
if label is None and show_label is None:
show_label = False
super().__init__(
value=value,
label=label,
Expand Down Expand Up @@ -273,11 +286,12 @@ def create_plot(
range_ = list(range(len(domain)))
type_ = "nominal"

color_legend_position = color_legend_position or "bottom"
encodings["color"] = {
"field": color,
"type": type_,
"legend": AltairPlot.create_legend(
position=color_legend_position, title=color_legend_title or color
position=color_legend_position, title=color_legend_title
),
"scale": {"domain": domain, "range": range_},
}
Expand All @@ -288,15 +302,15 @@ def create_plot(
"field": size,
"type": "quantitative" if is_numeric_dtype(value[size]) else "nominal",
"legend": AltairPlot.create_legend(
position=size_legend_position, title=size_legend_title or size
position=size_legend_position, title=size_legend_title
),
}
if shape:
encodings["shape"] = {
"field": shape,
"type": "quantitative" if is_numeric_dtype(value[shape]) else "nominal",
"legend": AltairPlot.create_legend(
position=shape_legend_position, title=shape_legend_title or shape
position=shape_legend_position, title=shape_legend_title
),
}
chart = (
Expand Down
Loading

0 comments on commit 797621b

Please sign in to comment.