diff --git a/src/moscot/plotting/_utils.py b/src/moscot/plotting/_utils.py index c3a6f5f62..4ca12184f 100644 --- a/src/moscot/plotting/_utils.py +++ b/src/moscot/plotting/_utils.py @@ -60,6 +60,8 @@ def _sankey( ) -> mpl.figure.Figure: if ax is None: fig, ax = plt.subplots(constrained_layout=True, dpi=dpi, figsize=figsize) + else: + fig = ax.figure if captions is not None and len(captions) != len(transition_matrices): raise ValueError(f"Expected captions to be of length `{len(transition_matrices)}`, found `{len(captions)}`.") if colorDict is None: @@ -206,6 +208,8 @@ def _heatmap( if ax is None: fig, ax = plt.subplots(constrained_layout=True, dpi=dpi, figsize=figsize) + else: + fig = ax.figure if row_annotation != AggregationMode.CELL: set_palette(adata=row_adata, key=row_annotation, cont_cmap=cont_cmap) if col_annotation != AggregationMode.CELL: