Skip to content

Commit

Permalink
Fix 3d plotting (#70)
Browse files Browse the repository at this point in the history
Creating Axes3D directly silently fails in newer versions of matplotlib.
Use the current recommended method of adding a 3d projection to a figure
instead.
  • Loading branch information
akshayka authored Nov 20, 2022
1 parent 7b926a0 commit 9aea8f6
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions pymde/experiment_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ def _is_discrete(dtype):
def _plot_3d(
X, color_by, cmap, colors, edges, s, background_color, figsize, lim
):
from mpl_toolkits.mplot3d import Axes3D

if isinstance(X, torch.Tensor):
X = X.cpu().numpy()
Expand All @@ -159,7 +158,7 @@ def _plot_3d(
shadowcolor = "gainsboro"

fig = plt.figure(figsize=figsize)
ax = Axes3D(fig)
ax = fig.add_subplot(projection='3d')

x, y, z = X[:, 0], X[:, 1], X[:, 2]

Expand Down Expand Up @@ -196,27 +195,26 @@ def _plot_3d(
ax.set_zlim(lim)

if edges is None:
ax.plot(
# shadows
ax.scatter(
y,
z,
"g+",
zdir="x",
zs=ax.axes.get_xlim3d()[0],
c=shadowcolor,
alpha=0.5,
marker="o",
markersize=shadowsize,
s=shadowsize,
)
ax.plot(
ax.scatter(
x,
y,
"k+",
zdir="z",
zs=ax.axes.get_zlim3d()[0],
c=shadowcolor,
alpha=0.5,
marker="o",
markersize=shadowsize,
s=shadowsize,
)

if background_color is not None:
Expand Down

0 comments on commit 9aea8f6

Please sign in to comment.