Skip to content

Commit

Permalink
add 3d pca scatterplot util
Browse files Browse the repository at this point in the history
  • Loading branch information
matiaslindgren committed Nov 12, 2020
1 parent 61417dc commit 49c27a0
Showing 1 changed file with 20 additions and 0 deletions.
20 changes: 20 additions & 0 deletions lidbox/visualize.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import numpy as np
import pandas as pd
import seaborn as sns


Expand Down Expand Up @@ -110,3 +112,21 @@ def plot_embedding_vector(v, cmap="RdBu_r", figsize=None):

plt.gcf().set_size_inches(*figsize)
plt.show()


def draw_3d_pca_scatterplot(pca_data_3d, data_labels):
df = pd.DataFrame.from_dict({
"x": pca_data_3d[:,0],
"y": pca_data_3d[:,1],
"z": pca_data_3d[:,2],
"label": data_labels,
})

fig = plt.figure()
ax = fig.add_subplot(111, projection="3d")

for label, group in df.groupby("label"):
ax.scatter(group.x, group.y, group.z, label=label)

ax.legend()
return fig, ax

0 comments on commit 49c27a0

Please sign in to comment.