Skip to content

Commit

Permalink
edgecolor fix (#582)
Browse files Browse the repository at this point in the history
  • Loading branch information
jordanplanders authored Jun 18, 2024
1 parent 6b79c6b commit 3080ce3
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 28 deletions.
6 changes: 5 additions & 1 deletion pyleoclim/core/geoseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import matplotlib.pyplot as plt
import re
import pandas as pd
import numpy as np

#from copy import deepcopy
from matplotlib import gridspec
Expand Down Expand Up @@ -497,9 +498,12 @@ def map_neighbors(self, mgs, radius=3000, projection='Orthographic', proj_defaul
neighbor_coloring = ['w' for ik in range(len(neighborhood))]
neighbor_coloring[-1] = 'k'
neighborhood['original'] =neighbor_coloring
neighborhood['neighbors'] =neighbor_coloring

# plot neighbors

fig, ax_d = mapping.scatter_map(neighborhood, fig=fig, gs_slot=gridspec_slot, hue=hue, size=size, marker=marker, projection=projection,
fig, ax_d = mapping.scatter_map(neighborhood, fig=fig, gs_slot=gridspec_slot, hue=hue, size=size,
marker=marker, projection=projection,
proj_default=proj_default,
background=background, borders=borders, rivers=rivers, lakes=lakes,
ocean=ocean, land=land,
Expand Down
96 changes: 69 additions & 27 deletions pyleoclim/utils/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -718,7 +718,7 @@ def scatter_map(geos, hue='archiveType', size=None, marker='archiveType', edgeco
# geos_df = pd.DataFrame(value_d)
# return geos_df

def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_var=None, edgecolor='w',
def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_var=None, edgecolor_var=None, edgecolor='w',
ax=None, ax_d=None, proj=None, scatter_kwargs=None, legend=True, lgd_kwargs=None, colorbar=None,
fig=None, color_scale_type=None, # gs_slot=None,
cmap=None, **kwargs):
Expand Down Expand Up @@ -773,16 +773,28 @@ def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_va
}
missing_val = missing_d['label']

# if 'edgecolor' in scatter_kwargs:
# edgecolor = scatter_kwargs['edgecolor']
# scatter_kwargs['edgecolors'] = edgecolor
# if 'edgecolors' not in scatter_kwargs:
# scatter_kwargs['edgecolors'] = edgecolor
#
if 'edgecolor' in scatter_kwargs:
edgecolor = scatter_kwargs['edgecolor']# = edgecolor
if isinstance(scatter_kwargs, dict):
edgecolor = scatter_kwargs.pop('edgecolor', edgecolor)

if 'neighbor' in df.columns:
edgecolor_var = 'neighbor'
# if ~isinstance(edgecolor, np.ndarray):
# if isinstance(edgecolor, str):
# edgecolor = [edgecolor]
# edgecolor = np.array(edgecolor)


if isinstance(edgecolor, (list, np.ndarray)):
if len(edgecolor) == len(_df):
_df['edgecolor'] = edgecolor
elif len(edgecolor) == 1:
_df['edgecolor'] = edgecolor[0]
# try making a column populated by the edgecolor
elif isinstance(edgecolor, str):
_df['edgecolor'] = edgecolor
elif isinstance(edgecolor, dict):
if edgecolor_var in _df.columns:
_df['edgecolor'] = _df[edgecolor_var].map(edgecolor)

hue_var = hue_var if hue_var in _df.columns else None
hue_var_type_numeric = False
Expand Down Expand Up @@ -935,23 +947,42 @@ def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_va
_df['edgecolor'] = edgecolor
_df['neighbor'] = _df['edgecolor'].map({'k': 'target', 'w': 'neighbor'})

# handle missing values
hue_data = _df[_df[hue_var] == missing_val]

if len(hue_data) > 0:
sns.scatterplot(data=hue_data, x=x, y=y, hue=hue_var, size=size_var,
style=marker_var, transform=transform,edgecolor='w',
# change to transform=scatter_kwargs['transform']
if 'neighbor' in hue_data.columns:
sns.scatterplot(data=hue_data, x=x, y=y, size=size_var,
transform=transform,
edgecolor=hue_data.edgecolor.values, linewidth=2,
style=marker_var, hue=hue_var, palette=[missing_d['hue'] for ik in range(len(hue_data))],
ax=ax, legend=False,
**scatter_kwargs)

sns.scatterplot(data=hue_data, x=x, y=y, size=size_var,
transform=transform,
edgecolor=None, linewidth=0,
style=marker_var, hue=hue_var,
palette=[missing_d['hue'] for ik in range(len(hue_data))],
ax=ax, **scatter_kwargs)
ax=ax,
**scatter_kwargs)
# sns.scatterplot(data=hue_data, x=x, y=y, hue=hue_var, size=size_var,
# style=marker_var, transform=transform,edgecolor=edgecolor,
# # change to transform=scatter_kwargs['transform']
# palette=[missing_d['hue'] for ik in range(len(hue_data))],
# ax=ax, **scatter_kwargs)
missing_handles, missing_labels = ax.get_legend_handles_labels()
if 'neighbor' in hue_data.columns:
if len(hue_data[hue_data['neighbor'] != 'neighbor']) > 1:
_edgecolor = hue_data[hue_data['neighbor'] != 'neighbor']['edgecolor'].values[0]
else:
_edgecolor = hue_data[hue_data['neighbor'] != 'neighbor']['edgecolor']
sns.scatterplot(data=hue_data[hue_data['neighbor'] != 'neighbor'], x=x, y=y, size=size_var,
transform=transform, edgecolor=_edgecolor,
style=marker_var, hue=hue_var, palette=palette, ax=ax, **scatter_kwargs)

# # if the missing values are being handled for map_neighbors
# if 'neighbor' in hue_data.columns:
# # if the missing value is actually the target
# if len(hue_data[hue_data['neighbor'] != 'neighbor']) > 1:
# _edgecolor = hue_data[hue_data['neighbor'] != 'neighbor']['edgecolor'].values[0]
# else:
# _edgecolor = hue_data[hue_data['neighbor'] != 'neighbor']['edgecolor']
# sns.scatterplot(data=hue_data[hue_data['neighbor'] != 'neighbor'], x=x, y=y, size=size_var,
# transform=transform, edgecolor=_edgecolor,
# style=marker_var, hue=hue_var, palette=palette, ax=ax, **scatter_kwargs)
# available values
else:
missing_handles, missing_labels = [], []

Expand All @@ -960,12 +991,23 @@ def plot_scatter(df=None, x=None, y=None, hue_var=None, size_var=None, marker_va
if hue_norm is not None:
scatter_kwargs['hue_norm'] = hue_norm

sns.scatterplot(data=hue_data, x=x, y=y, hue=hue_var, size=size_var, transform=transform,edgecolor='w',
style=marker_var, palette=palette, ax=ax, **scatter_kwargs)
# sns.scatterplot(data=hue_data, x=x, y=y, hue=hue_var, size=size_var, transform=transform,edgecolor=edgecolor,
# style=marker_var, palette=palette, ax=ax, **scatter_kwargs)

if 'neighbor' in hue_data.columns:
sns.scatterplot(data=hue_data[hue_data['neighbor'] != 'neighbor'], x=x, y=y, size=size_var,
transform=transform, edgecolor=hue_data[hue_data['neighbor'] != 'neighbor']['edgecolor'].values[0],
style=marker_var, hue=hue_var, palette=palette, ax=ax, **scatter_kwargs)
sns.scatterplot(data=hue_data, x=x, y=y, size=size_var,
transform=transform,
edgecolor=hue_data.edgecolor.values,linewidth=2,
style=marker_var, hue=hue_var, palette=palette, ax=ax, legend=False, **scatter_kwargs)
if not isinstance(edgecolor, str):
edgecolor = None
linewidth = 0
else:
linewidth = 1

sns.scatterplot(data=hue_data, x=x, y=y, hue=hue_var, size=size_var, transform=transform,
edgecolor=edgecolor,linewidth=linewidth,
style=marker_var, palette=palette, ax=ax, **scatter_kwargs)

else:
scatter_kwargs['zorder'] = 13
Expand Down

0 comments on commit 3080ce3

Please sign in to comment.