forked from benfulcher/hctsaAnalysisPython
-
Notifications
You must be signed in to change notification settings - Fork 0
/
umap_projection.py
55 lines (49 loc) · 2.21 KB
/
umap_projection.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
#-------------------------------------------------------------------------------
# Visualize labeled HCTSA dataset as a UMAP projection
# Learn about umap here:
# http://umap-learn.readthedocs.io/en/latest/basic_usage.html
#-------------------------------------------------------------------------------
import matplotlib as plt
import seaborn as sns
import pandas as pd
import umap
# Global seaborn theme: notebook context, plain white background, large figures.
sns.set(context='notebook', style='white', rc={'figure.figsize': (14, 10)})
# Default CSV file names written by hctsa's OutputToCSV
# (LoadResults reads both without a header row):
dataMatrixCSV = 'hctsa_datamatrix.csv'
timeSeriesInfoCSV = 'hctsa_timeseries-info.csv'
#-------------------------------------------------------------------------------
def LoadResults(dataMatrixFile=None, timeSeriesInfoFile=None):
    """Load hctsa results exported by OutputToCSV as pandas DataFrames.

    Parameters
    ----------
    dataMatrixFile : str or path-like, optional
        CSV holding the numeric data matrix, read without a header row.
        Defaults to the module-level ``dataMatrixCSV``.
    timeSeriesInfoFile : str or path-like, optional
        CSV with two columns (name, label), read without a header row.
        Defaults to the module-level ``timeSeriesInfoCSV``.

    Returns
    -------
    dataMatrix : pandas.DataFrame
        The data matrix as read from ``dataMatrixFile``.
    tsLabels : pandas.DataFrame
        Columns ``name`` and ``label``; ``label`` has categorical dtype.

    Raises
    ------
    ValueError
        If the two files have different numbers of rows. Labels are attached
        to data-matrix rows by row index downstream, so a mismatch would
        otherwise silently produce missing or wrong labels.
    """
    # Resolve defaults at call time (as the original did by reading the
    # globals directly), so reassigning the module-level names still works.
    if dataMatrixFile is None:
        dataMatrixFile = dataMatrixCSV
    if timeSeriesInfoFile is None:
        timeSeriesInfoFile = timeSeriesInfoCSV

    dataMatrix = pd.read_csv(dataMatrixFile, header=None)
    tsLabels = pd.read_csv(timeSeriesInfoFile, header=None,
                           names=['name', 'label'])
    tsLabels['label'] = tsLabels['label'].astype('category')

    if len(dataMatrix) != len(tsLabels):
        raise ValueError(
            f"{dataMatrixFile} has {len(dataMatrix)} rows but "
            f"{timeSeriesInfoFile} has {len(tsLabels)} rows; "
            "each data-matrix row needs exactly one label row"
        )
    return dataMatrix, tsLabels
#-------------------------------------------------------------------------------
def UMAP_embed(dataMatrix, **umapParams):
    """Embed the rows of ``dataMatrix`` with UMAP and return a DataFrame.

    Parameters
    ----------
    dataMatrix : array-like or pandas.DataFrame
        Data to embed, one observation per row.
    **umapParams
        Keyword arguments forwarded unchanged to ``umap.UMAP`` (for example
        ``random_state`` for reproducible runs, or ``n_components``). With no
        arguments UMAP's defaults are used, giving the same 2-d projection as
        before.

    Returns
    -------
    pandas.DataFrame
        One row per input row, with columns ``umap-1``, ``umap-2``, ... (one
        per embedding dimension). If ``dataMatrix`` is a DataFrame its index
        is reused, so per-row metadata can be attached by index alignment.
    """
    reducer = umap.UMAP(**umapParams)
    embedding = reducer.fit_transform(dataMatrix)
    # Derive column names from the actual embedding width so that
    # n_components other than 2 is supported.
    columns = ['umap-{}'.format(i + 1) for i in range(embedding.shape[1])]
    # Only a DataFrame carries a meaningful row index (a list's `.index` is
    # a method), so check the type explicitly.
    index = dataMatrix.index if isinstance(dataMatrix, pd.DataFrame) else None
    return pd.DataFrame(data=embedding, columns=columns, index=index)
#-------------------------------------------------------------------------------
def plot_projection(df, doSave=True, fileName='umapProjection.pdf', title=None):
    """Scatter-plot a two-dimensional projection, coloured by label.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns ``umap-1``, ``umap-2`` and ``label``.
    doSave : bool, optional
        If True (default), save the figure to ``fileName``.
    fileName : str or path-like, optional
        Output path used when ``doSave`` is True
        (default ``'umapProjection.pdf'``).
    title : str, optional
        If given, set as the plot title (font size 24).

    Returns
    -------
    seaborn.FacetGrid
        The grid returned by ``seaborn.lmplot``.
    """
    # legend_out is not passed: it is deprecated as a direct lmplot argument
    # in newer seaborn, and True is already the FacetGrid default.
    lowDim = sns.lmplot(x='umap-1', y='umap-2', data=df, fit_reg=False,
                        markers='.', hue='label', legend=True,
                        palette='Set2')
    if title is not None:
        # Use the grid's own Axes instead of matplotlib's "current axes".
        lowDim.axes.flat[0].set_title(title, fontsize=24)
    if doSave:
        lowDim.savefig(fileName)
    return lowDim
#-------------------------------------------------------------------------------
def main():
    """Load the hctsa CSV output, embed it with UMAP, and plot (and save) it."""
    dataMatrix, tsLabels = LoadResults()
    projection = UMAP_embed(dataMatrix)
    # Attach each series' label to its embedded point; pandas aligns the
    # assignment on the row index of the two frames.
    projection['label'] = tsLabels['label']
    plot_projection(projection)


if __name__ == "__main__":
    main()