diff --git a/pymde/experiment_utils.py b/pymde/experiment_utils.py index 940bb09..889adbe 100644 --- a/pymde/experiment_utils.py +++ b/pymde/experiment_utils.py @@ -150,7 +150,6 @@ def _is_discrete(dtype): def _plot_3d( X, color_by, cmap, colors, edges, s, background_color, figsize, lim ): - from mpl_toolkits.mplot3d import Axes3D if isinstance(X, torch.Tensor): X = X.cpu().numpy() @@ -159,7 +158,7 @@ def _plot_3d( shadowcolor = "gainsboro" fig = plt.figure(figsize=figsize) - ax = Axes3D(fig) + ax = fig.add_subplot(projection='3d') x, y, z = X[:, 0], X[:, 1], X[:, 2] @@ -196,27 +195,26 @@ def _plot_3d( ax.set_zlim(lim) if edges is None: - ax.plot( + # shadows + ax.scatter( y, z, - "g+", zdir="x", zs=ax.axes.get_xlim3d()[0], c=shadowcolor, alpha=0.5, marker="o", - markersize=shadowsize, + s=shadowsize, ) - ax.plot( + ax.scatter( x, y, - "k+", zdir="z", zs=ax.axes.get_zlim3d()[0], c=shadowcolor, alpha=0.5, marker="o", - markersize=shadowsize, + s=shadowsize, ) if background_color is not None: