Skip to content

Commit

Permalink
Stabilize the vega-lite spec for altair-based charts (streamlit#8628)
Browse files Browse the repository at this point in the history
## Describe your changes

Altair uses global counters to generate names for unnamed parameters and
views/subcharts (see [this issue](https://github.com/vega/altair/issues/3416)).
This leads to different vega-lite chart specs across reruns for the same
chart, since every newly created chart increments the counters and thereby
changes the auto-generated names. Some of our features, like the forward
message cache and the upcoming chart selections, require a stable identity
for the chart element.

Therefore, we are applying a (temporary) slightly hacky solution to
replace these generated names with names that are stable across reruns.
The relevant naming patterns are: `view_<number>` and `param_<number>`.
Eventually, we can replace this hack if it gets addressed in Altair (see
issue [here](vega/altair#3416))

I performance-tested the `_stabilize_vega_json_spec` function with a
bigger concatenated chart spec by running it 10k times. The total
runtime was ~0.3s, i.e. ~0.03 milliseconds per execution. There are a
couple of ways to make this a lot more performant, but I think it would
be fine to put that into a follow-up PR after the selections release.

## Testing Plan

- Added unit tests.

---

**Contribution License Agreement**

By submitting this pull request you agree that all contributions to this
project are made under the Apache 2.0 license.
  • Loading branch information
lukasmasuch authored and benjamin-awd committed Sep 29, 2024
1 parent 3583f26 commit 9f28b0a
Show file tree
Hide file tree
Showing 10 changed files with 368 additions and 13 deletions.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
126 changes: 117 additions & 9 deletions lib/streamlit/elements/vega_charts.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

from __future__ import annotations

import hashlib
import json
import re
from contextlib import nullcontext
from typing import TYPE_CHECKING, Any, Final, Literal, Sequence, cast

Expand All @@ -32,6 +34,7 @@
ArrowVegaLiteChart as ArrowVegaLiteChartProto,
)
from streamlit.runtime.metrics_util import gather_metrics
from streamlit.util import HASHLIB_KWARGS

if TYPE_CHECKING:
import altair as alt
Expand Down Expand Up @@ -128,13 +131,25 @@ def _marshall_chart_data(
These operations will happen in-place."""

# Pull data out of spec dict when it's in a 'datasets' key:
# datasets: {foo: df1, bar: df2}, ...}
# datasets: {foo: df1_bytes, bar: df2_bytes}, ...}
if "datasets" in spec:
for dataset_name, dataset_data in spec["datasets"].items():
dataset = proto.datasets.add()
dataset.name = str(dataset_name)
dataset.has_name = True
dataset.data.data = _serialize_data(dataset_data)
# The ID transformer (id_transform function registered before conversion to dict)
# already serializes the data into Arrow IPC format (bytes) when the Altair object
# gets converted into the vega-lite spec dict.
# If its already in bytes, we don't need to serialize it here again.
# We just need to pass the data information into the correct proto fields.

# TODO(lukasmasuch): Are there any other cases where we need to serialize the data
# or can we remove the _serialize_data here?
dataset.data.data = (
dataset_data
if isinstance(dataset_data, bytes)
else _serialize_data(dataset_data)
)
del spec["datasets"]

# Pull data out of spec dict when it's in a top-level 'data' key:
Expand Down Expand Up @@ -169,11 +184,21 @@ def _convert_altair_to_vega_lite_spec(altair_chart: alt.Chart) -> dict[str, Any]
datasets = {}

def id_transform(data) -> dict[str, str]:
"""Altair data transformer that returns a fake named dataset with the
object id.
"""Altair data transformer that serializes the data,
creates a stable name based on the hash of the data,
stores the bytes into the datasets mapping and
returns this name to have it be used in Altair.
"""
name = str(id(data))
datasets[name] = data

# Already serialize the data to be able to create a stable
# dataset name:
data_bytes = _serialize_data(data)
# Use the md5 hash of the data as the name:
h = hashlib.new("md5", **HASHLIB_KWARGS)
h.update(str(data_bytes).encode("utf-8"))
name = h.hexdigest()

datasets[name] = data_bytes
return {"name": name}

alt.data_transformers.register("id", id_transform) # type: ignore[attr-defined,unused-ignore]
Expand All @@ -185,12 +210,94 @@ def id_transform(data) -> dict[str, str]:
with alt.data_transformers.enable("id"): # type: ignore[attr-defined,unused-ignore]
chart_dict = altair_chart.to_dict()

# Put datasets back into the chart dict but note how they weren't
# transformed.
# Put datasets back into the chart dict:
chart_dict["datasets"] = datasets
return chart_dict


def _reset_counter_pattern(prefix: str, vega_spec: str) -> str:
"""Altair uses a global counter for unnamed parameters and views.
We need to reset these counters on a spec-level to make the
spec stable across reruns and avoid changes to the element ID.
"""
pattern = re.compile(rf'"{prefix}\d+"')
# Get all matches without duplicates in order of appearance.
# Using a set here would not guarantee the order of appearance,
# which might lead to different replacements on each run.
# The order of the spec from Altair is expected to stay stable
# within the same session / Altair version.
# The order might change with Altair updates, but that's not really
# a case that is relevant for us since we mainly care about having
# this stable within a session.
if matches := list(dict.fromkeys(pattern.findall(vega_spec))):
# Add a prefix to the replacement to avoid
# replacing instances that already have been replaced before.
# The prefix here is arbitrarily chosen with the main goal
# that its extremely unlikely to already be part of the spec:
replacement_prefix = "__replace_prefix_o9hd101n22e1__"

# Replace all matches with a counter starting from 1
# We start from 1 to imitate the altair behavior.
for counter, match in enumerate(matches, start=1):
vega_spec = vega_spec.replace(
match, f'"{replacement_prefix}{prefix}{counter}"'
)

# Remove the prefix again from all replacements:
vega_spec = vega_spec.replace(replacement_prefix, "")
return vega_spec


def _stabilize_vega_json_spec(vega_spec: str) -> str:
"""Makes the chart spec stay stable across reruns and sessions.
Altair auto creates names for unnamed parameters & views. It uses a global counter
for the naming which will result in a different spec on every rerun.
In Streamlit, we need the spec to be stable across reruns and sessions to prevent the chart
from getting a new identity. So we need to replace the names with counter with a stable name.
Having a stable chart spec is also important for features like forward message cache,
where we don't want to have changing messages on every rerun.
Parameter counter:
https://github.com/vega/altair/blob/f345cd9368ae2bbc98628e9245c93fa9fb582621/altair/vegalite/v5/api.py#L196
View counter:
https://github.com/vega/altair/blob/f345cd9368ae2bbc98628e9245c93fa9fb582621/altair/vegalite/v5/api.py#L2885
This is temporary solution waiting for a fix for this issue:
https://github.com/vega/altair/issues/3416
Other solutions we considered:
- working on the dict object: this would require to iterate through the object and do the
same kind of replacement; though we would need to know the structure and since we need
the spec in String-format anyways, we deemed that executing the replacement on the
String is the better alternative
- resetting the counter: the counter is incremented already when the chart object is created
(see this GitHub issue comment https://github.com/vega/altair/issues/3416#issuecomment-2098530464),
so it would be too late here to reset the counter with a thread-lock to prevent interference
between sessions
"""

# We only want to apply these replacements if it is really necessary
# since there is a risk that we replace names that where chosen by the user
# and thereby introduce unwanted side effects.

# We only need to apply the param_ fix if there are actually parameters defined
# somewhere in the spec. We can check for this by looking for the '"params"' key.
# This isn't a perfect check, but good enough to prevent unnecessary executions
# for the majority of charts.
if '"params"' in vega_spec:
vega_spec = _reset_counter_pattern("param_", vega_spec)

# Simple check if the spec contains a composite chart:
# https://vega.github.io/vega-lite/docs/composition.html
# Other charts will not contain the `view_` name,
# so its better to not replace this pattern.
if re.search(r'"(vconcat|hconcat|facet|layer|concat|repeat)"', vega_spec):
vega_spec = _reset_counter_pattern("view_", vega_spec)
return vega_spec


class VegaChartsMixin:
"""Mix-in class for all vega-related chart commands.
Expand Down Expand Up @@ -1020,7 +1127,8 @@ def _vega_lite_chart(
spec = _prepare_vega_lite_spec(spec, use_container_width, **kwargs)
_marshall_chart_data(proto, spec, data)

proto.spec = json.dumps(spec)
# Prevent the spec from changing across reruns:
proto.spec = _stabilize_vega_json_spec(json.dumps(spec))
proto.use_container_width = use_container_width
proto.theme = theme or ""

Expand Down
53 changes: 50 additions & 3 deletions lib/streamlit/type_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,11 +782,15 @@ def is_pandas_version_less_than(v: str) -> bool:
-------
bool
Raises
------
InvalidVersion
If the version strings are not valid.
"""
import pandas as pd
from packaging import version

return version.parse(pd.__version__) < version.parse(v)
return is_version_less_than(pd.__version__, v)


def is_pyarrow_version_less_than(v: str) -> bool:
Expand All @@ -801,11 +805,54 @@ def is_pyarrow_version_less_than(v: str) -> bool:
-------
bool
Raises
------
InvalidVersion
If the version strings are not valid.
"""
import pyarrow as pa

return is_version_less_than(pa.__version__, v)


def is_altair_version_less_than(v: str) -> bool:
    """Check whether the installed Altair version is lower than ``v``.

    Parameters
    ----------
    v : str
        Version string to compare against, e.g. "0.25.0"

    Returns
    -------
    bool
        True if the current Altair version is less than ``v``.

    Raises
    ------
    InvalidVersion
        If the version strings are not valid.
    """
    # Imported inside the function so that Altair is only loaded when this
    # check is actually called.
    import altair as alt

    installed_version = alt.__version__
    return is_version_less_than(installed_version, v)


def is_version_less_than(v1: str, v2: str) -> bool:
    """Return True if the v1 version string is less than the v2 version string.

    Both strings are parsed with ``packaging.version.parse`` and the parsed
    versions are compared, so "1.2.0" is correctly considered lower than
    "1.10.0" (unlike a plain string comparison).

    Parameters
    ----------
    v1 : str
        Version string on the left-hand side of the comparison.
    v2 : str
        Version string on the right-hand side of the comparison.

    Returns
    -------
    bool
        True if ``v1`` is strictly lower than ``v2``, False otherwise
        (including when both are equal).

    Raises
    ------
    InvalidVersion
        If the version strings are not valid.
    """
    from packaging import version

    # Only the two arguments take part in the comparison; this helper must
    # not reference any particular library's ``__version__``.
    return version.parse(v1) < version.parse(v2)


def _maybe_truncate_table(
Expand Down
Loading

0 comments on commit 9f28b0a

Please sign in to comment.