Skip to content

Commit

Permalink
Leaderboard: Refined plots (#1601)
Browse files Browse the repository at this point in the history
* Added embedding size guide to performance-size plot, removed shading on radar chart

* Changed plot names to something more descriptive

* Made plots failsafe
  • Loading branch information
x-tabdeveloping authored Dec 16, 2024
1 parent 95d5ae5 commit 0c9e046
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 9 deletions.
4 changes: 2 additions & 2 deletions mteb/leaderboard/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,12 +218,12 @@ def update_task_info(task_names: str) -> gr.DataFrame:
)
citation = gr.Markdown(update_citation, inputs=[benchmark_select])
with gr.Column():
with gr.Tab("Performance-Size Plot"):
with gr.Tab("Performance per Model Size"):
plot = gr.Plot(performance_size_plot, inputs=[summary_table])
gr.Markdown(
"*We only display models that have been run on all tasks in the benchmark*"
)
with gr.Tab("Top 5 Radar Chart"):
with gr.Tab("Performance per Task Type (Radar Chart)"):
radar_plot = gr.Plot(radar_chart, inputs=[summary_table])
gr.Markdown(
"*We only display models that have been run on all task types in the benchmark*"
Expand Down
95 changes: 88 additions & 7 deletions mteb/leaderboard/figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,28 @@
import plotly.graph_objects as go


def text_plot(text: str):
"""Returns empty scatter plot with text added, this can be great for error messages."""
return px.scatter(template="plotly_white").add_annotation(
text=text, showarrow=False, font=dict(size=20)
)


def failsafe_plot(fun):
"""Decorator that turns the function producing a figure failsafe.
This is necessary, because once a Callback encounters an exception it
becomes useless in Gradio.
"""

def wrapper(*args, **kwargs):
try:
return fun(*args, **kwargs)
except Exception:
return text_plot("Couldn't produce plot.")

return wrapper


def parse_n_params(text: str) -> int:
if text.endswith("M"):
return float(text[:-1]) * 1e6
Expand Down Expand Up @@ -37,6 +59,48 @@ def parse_float(value) -> float:
]


def add_size_guide(fig: go.Figure):
xpos = [5 * 1e9] * 4
ypos = [7.8, 8.5, 9, 10]
sizes = [256, 1024, 2048, 4096]
fig.add_trace(
go.Scatter(
showlegend=False,
opacity=0.3,
mode="markers",
marker=dict(
size=np.sqrt(sizes),
color="rgba(0,0,0,0)",
line=dict(color="black", width=2),
),
x=xpos,
y=ypos,
)
)
fig.add_annotation(
text="<b>Embedding Size:</b>",
font=dict(size=16),
x=np.log10(1.5e9),
y=10,
showarrow=False,
opacity=0.3,
)
for x, y, size in zip(xpos, np.linspace(7.5, 14, 4), sizes):
fig.add_annotation(
text=f"<b>{size}</b>",
font=dict(size=12),
x=np.log10(x),
y=y,
showarrow=True,
ay=0,
ax=50,
opacity=0.3,
arrowwidth=2,
)
return fig


@failsafe_plot
def performance_size_plot(df: pd.DataFrame) -> go.Figure:
df = df.copy()
df["Number of Parameters"] = df["Number of Parameters"].map(parse_n_params)
Expand All @@ -50,14 +114,15 @@ def performance_size_plot(df: pd.DataFrame) -> go.Figure:
if not len(df.index):
return go.Figure()
min_score, max_score = df["Mean (Task)"].min(), df["Mean (Task)"].max()
df["sqrt(dim)"] = np.sqrt(df["Embedding Dimensions"])
fig = px.scatter(
df,
x="Number of Parameters",
y="Mean (Task)",
log_x=True,
template="plotly_white",
text="model_text",
size="Embedding Dimensions",
size="sqrt(dim)",
color="Log(Tokens)",
range_color=[2, 5],
range_x=[8 * 1e6, 11 * 1e9],
Expand All @@ -69,10 +134,21 @@ def performance_size_plot(df: pd.DataFrame) -> go.Figure:
"Mean (Task)": True,
"Rank (Borda)": True,
"Log(Tokens)": False,
"sqrt(dim)": False,
"model_text": False,
},
hover_name="Model",
)
# Note: it's important that this comes before setting the size mode
fig = add_size_guide(fig)
fig.update_traces(
marker=dict(
sizemode="diameter",
sizeref=1.5,
sizemin=0,
)
)
fig.add_annotation(x=1e9, y=10, text="Model size:")
fig.update_layout(
coloraxis_colorbar=dict( # noqa
title="Max Tokens",
Expand Down Expand Up @@ -124,21 +200,26 @@ def performance_size_plot(df: pd.DataFrame) -> go.Figure:
"#3CBBB1",
]
fill_colors = [
"rgba(238,66,102,0.2)",
"rgba(0,166,237,0.2)",
"rgba(236,167,44,0.2)",
"rgba(180,35,24,0.2)",
"rgba(60,187,177,0.2)",
"rgba(238,66,102,0.05)",
"rgba(0,166,237,0.05)",
"rgba(236,167,44,0.05)",
"rgba(180,35,24,0.05)",
"rgba(60,187,177,0.05)",
]


@failsafe_plot
def radar_chart(df: pd.DataFrame) -> go.Figure:
df = df.copy()
df["Model"] = df["Model"].map(parse_model_name)
# Remove whitespace
task_type_columns = [
column for column in df.columns if "".join(column.split()) in task_types
]
if len(task_type_columns) <= 1:
raise ValueError(
"Couldn't produce radar chart, the benchmark only contains one task category."
)
df = df[["Model", *task_type_columns]].set_index("Model")
df = df.replace("", np.nan)
df = df.dropna()
Expand All @@ -156,7 +237,7 @@ def radar_chart(df: pd.DataFrame) -> go.Figure:
mode="lines",
line=dict(width=2, color=line_colors[i]),
fill="toself",
fillcolor=fill_colors[i],
fillcolor="rgba(0,0,0,0)",
)
)
fig.update_layout(
Expand Down

0 comments on commit 0c9e046

Please sign in to comment.