Skip to content

Commit

Permalink
Merge branch 'px_special_inputs' into wide_form2
Browse files Browse the repository at this point in the history
  • Loading branch information
nicolaskruchten committed Apr 3, 2020
2 parents 236cd2c + 918b87b commit b50cd08
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 5 deletions.
7 changes: 7 additions & 0 deletions packages/python/plotly/plotly/express/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@
get_trendline_results,
)

from ._special_inputs import ( # noqa: F401
IdentityMap,
Constant,
)

from . import data, colors # noqa: F401

__all__ = [
Expand Down Expand Up @@ -95,4 +100,6 @@
"colors",
"set_mapbox_access_token",
"get_trendline_results",
"IdentityMap",
"Constant",
]
29 changes: 24 additions & 5 deletions packages/python/plotly/plotly/express/_core.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import plotly.graph_objs as go
import plotly.io as pio
from collections import namedtuple, OrderedDict
from ._special_inputs import IdentityMap, Constant

from _plotly_utils.basevalidators import ColorscaleValidator
from .colors import qualitative, sequential
Expand Down Expand Up @@ -41,6 +42,7 @@ def __init__(self):
defaults = PxDefaults()
del PxDefaults


MAPBOX_TOKEN = None


Expand Down Expand Up @@ -141,11 +143,15 @@ def make_mapping(args, variable):
if variable == "dash":
arg_name = "line_dash"
vprefix = "line_dash"
if args[vprefix + "_map"] == "identity":
val_map = IdentityMap()
else:
val_map = args[vprefix + "_map"].copy()
return Mapping(
show_in_trace_name=True,
variable=variable,
grouper=args[arg_name],
val_map=args[vprefix + "_map"].copy(),
val_map=val_map,
sequence=args[vprefix + "_sequence"],
updater=lambda trace, v: trace.update({parent: {variable: v}}),
facet=None,
Expand Down Expand Up @@ -937,6 +943,8 @@ def build_dataframe(args, attrables, array_attrables, constructor):
else:
df_output[df_input.columns] = df_input[df_input.columns]

constants = dict()

# Loop over possible arguments
for field_name in attrables:
# Massaging variables
Expand Down Expand Up @@ -968,8 +976,15 @@ def build_dataframe(args, attrables, array_attrables, constructor):
"pandas MultiIndex is not supported by plotly express "
"at the moment." % field
)
# ----------------- argument is a constant ----------------------
if isinstance(argument, Constant):
col_name = _check_name_not_reserved(
str(argument.label) if argument.label is not None else field,
reserved_names,
)
constants[col_name] = argument.value
# ----------------- argument is a col name ----------------------
if isinstance(argument, str) or isinstance(
elif isinstance(argument, str) or isinstance(
argument, int
): # just a column name given as str or int
if not df_provided:
Expand Down Expand Up @@ -1073,6 +1088,9 @@ def build_dataframe(args, attrables, array_attrables, constructor):
args["x" if orient_v else "y"] = "value"
args["color"] = args["color"] or "variable"

for col_name in constants:
df_output[col_name] = constants[col_name]

args["data_frame"] = df_output
return args

Expand Down Expand Up @@ -1491,9 +1509,10 @@ def make_figure(args, constructor, trace_patch=None, layout_patch=None):
for col, val, m in zip(grouper, group_name, grouped_mappings):
if col != one_group:
key = get_label(args, col)
mapping_labels[key] = str(val)
if m.show_in_trace_name:
trace_name_labels[key] = str(val)
if not isinstance(m.val_map, IdentityMap):
mapping_labels[key] = str(val)
if m.show_in_trace_name:
trace_name_labels[key] = str(val)
if m.variable == "animation_frame":
frame_name = val
trace_name = ", ".join(trace_name_labels.values())
Expand Down
29 changes: 29 additions & 0 deletions packages/python/plotly/plotly/express/_special_inputs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
class IdentityMap(object):
"""
`dict`-like object which acts as if the value for any key is the key itself. Objects
of this class can be passed in to arguments like `color_discrete_map` to
use the provided data values as colors, rather than mapping them to colors cycled
from `color_discrete_sequence`. This works for any `_map` argument to Plotly Express
functions, such as `line_dash_map` and `symbol_map`.
"""

def __getitem__(self, key):
return key

def __contains__(self, key):
return True

def copy(self):
return self


class Constant(object):
"""
Objects of this class can be passed to Plotly Express functions that expect column
identifiers or list-like objects to indicate that this attribute should take on a
constant value. An optional label can be provided.
"""

def __init__(self, value, label=None):
self.value = value
self.label = label
Original file line number Diff line number Diff line change
Expand Up @@ -323,3 +323,61 @@ def test_size_column():
df = px.data.tips()
fig = px.scatter(df, x=df["size"], y=df.tip)
assert fig.data[0].hovertemplate == "size=%{x}<br>tip=%{y}<extra></extra>"


def test_identity_map():
fig = px.scatter(
x=[1, 2],
y=[1, 2],
symbol=["a", "b"],
color=["red", "blue"],
color_discrete_map=px.IdentityMap(),
)
assert fig.data[0].marker.color == "red"
assert fig.data[1].marker.color == "blue"
assert "color=" not in fig.data[0].hovertemplate
assert "symbol=" in fig.data[0].hovertemplate
assert fig.layout.legend.title.text == "symbol"

fig = px.scatter(
x=[1, 2],
y=[1, 2],
symbol=["a", "b"],
color=["red", "blue"],
color_discrete_map="identity",
)
assert fig.data[0].marker.color == "red"
assert fig.data[1].marker.color == "blue"
assert "color=" not in fig.data[0].hovertemplate
assert "symbol=" in fig.data[0].hovertemplate
assert fig.layout.legend.title.text == "symbol"


def test_constants():
fig = px.scatter(x=px.Constant(1), y=[1, 2])
assert fig.data[0].x[0] == 1
assert fig.data[0].x[1] == 1
assert "x=" in fig.data[0].hovertemplate

fig = px.scatter(x=px.Constant(1, label="time"), y=[1, 2])
assert fig.data[0].x[0] == 1
assert fig.data[0].x[1] == 1
assert "x=" not in fig.data[0].hovertemplate
assert "time=" in fig.data[0].hovertemplate

fig = px.scatter(
x=[1, 2],
y=[1, 2],
symbol=["a", "b"],
color=px.Constant("red", label="the_identity_label"),
hover_data=[px.Constant("data", label="the_data")],
color_discrete_map=px.IdentityMap(),
)
assert fig.data[0].marker.color == "red"
assert fig.data[0].customdata[0][0] == "data"
assert fig.data[1].marker.color == "red"
assert "color=" not in fig.data[0].hovertemplate
assert "the_identity_label=" not in fig.data[0].hovertemplate
assert "symbol=" in fig.data[0].hovertemplate
assert "the_data=" in fig.data[0].hovertemplate
assert fig.layout.legend.title.text == "symbol"

0 comments on commit b50cd08

Please sign in to comment.