Skip to content
This repository has been archived by the owner on Jun 28, 2024. It is now read-only.

Commit

Permalink
Fixed formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
noaahh committed Jun 4, 2024
1 parent 1a0c07d commit 5242fa6
Showing 1 changed file with 30 additions and 27 deletions.
57 changes: 30 additions & 27 deletions notebooks/embedding_analysis.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@
"source": [
"import os\n",
"import sys\n",
"from dotenv import load_dotenv\n",
"import plotly.express as px\n",
"import plotly\n",
"\n",
"import pandas as pd\n",
"import plotly\n",
"import plotly.express as px\n",
"from dotenv import load_dotenv\n",
"\n",
"current_dir = os.getcwd()\n",
"parent_dir = os.path.dirname(current_dir)\n",
Expand Down Expand Up @@ -80,21 +81,23 @@
"import pandas as pd\n",
"from sklearn.decomposition import PCA\n",
"\n",
"\n",
"def break_content(text, length=50):\n",
" lines = []\n",
" while len(text) > length:\n",
" space_index = text.rfind(' ', 0, length)\n",
" if space_index == -1:\n",
" space_index = length\n",
" lines.append(text[:space_index])\n",
" text = text[space_index:].lstrip()\n",
" lines.append(text)\n",
" return '<br>'.join(lines)\n",
" lines = []\n",
" while len(text) > length:\n",
" space_index = text.rfind(' ', 0, length)\n",
" if space_index == -1:\n",
" space_index = length\n",
" lines.append(text[:space_index])\n",
" text = text[space_index:].lstrip()\n",
" lines.append(text)\n",
" return '<br>'.join(lines)\n",
"\n",
"\n",
"def plot_pca(weak_labelled, key):\n",
" if key not in weak_labelled:\n",
" raise ValueError(f\"File {key} not found in the weak_labelled dictionary.\")\n",
" \n",
"\n",
" df = weak_labelled[key]\n",
"\n",
" embeddings = np.vstack(df['embedding_vec'].values)\n",
Expand All @@ -103,12 +106,12 @@
"\n",
" pca = PCA(n_components=3)\n",
" reduced_embeddings = pca.fit_transform(embeddings)\n",
" \n",
"\n",
" pca_df = pd.DataFrame(reduced_embeddings, columns=['PCA1', 'PCA2', 'PCA3'])\n",
" pca_df['Label'] = labels\n",
" pca_df['Content'] = content\n",
"\n",
" fig = px.scatter_3d(pca_df, x='PCA1', y='PCA2', z='PCA3', color='Label', \n",
" fig = px.scatter_3d(pca_df, x='PCA1', y='PCA2', z='PCA3', color='Label',\n",
" title=f'PCA of Embedding Vectors for {key}',\n",
" size_max=5, opacity=0.6, height=800,\n",
" hover_data={'Content': True})\n",
Expand All @@ -134,8 +137,8 @@
"\n",
"knn_key = 'mlp_weak_labeling_weaklabels.parquet'\n",
"\n",
"knn_wl_ds = create_dataset('knn', weak_labelled[knn_key], \n",
" weak_labelled[knn_key]['embedding_vec'], \n",
"knn_wl_ds = create_dataset('knn', weak_labelled[knn_key],\n",
" weak_labelled[knn_key]['embedding_vec'],\n",
" weak_labelled[knn_key]['label'])\n",
"\n",
"px_session = launch_px(knn_wl_ds, None)\n",
Expand Down Expand Up @@ -198,9 +201,9 @@
"source": [
"log_reg_key = 'log_reg_weak_labeling_weaklabels.parquet'\n",
"\n",
"log_reg_wl_ds = create_dataset('log_reg', weak_labelled[log_reg_key], \n",
" weak_labelled[log_reg_key]['embedding_vec'], \n",
" weak_labelled[log_reg_key]['label'])\n",
"log_reg_wl_ds = create_dataset('log_reg', weak_labelled[log_reg_key],\n",
" weak_labelled[log_reg_key]['embedding_vec'],\n",
" weak_labelled[log_reg_key]['label'])\n",
"\n",
"px_session = launch_px(log_reg_wl_ds, None)\n",
"px_session.view()"
Expand Down Expand Up @@ -230,8 +233,8 @@
"source": [
"mlp_key = 'mlp_weak_labeling_weaklabels.parquet'\n",
"\n",
"mlp_wl_ds = create_dataset('mlp_reg', weak_labelled[mlp_key], \n",
" weak_labelled[mlp_key]['embedding_vec'], \n",
"mlp_wl_ds = create_dataset('mlp_reg', weak_labelled[mlp_key],\n",
" weak_labelled[mlp_key]['embedding_vec'],\n",
" weak_labelled[mlp_key]['label'])\n",
"\n",
"px_session = launch_px(mlp_wl_ds, None)\n",
Expand Down Expand Up @@ -262,9 +265,9 @@
"source": [
"rf_key = 'rf_weak_labeling_weaklabels.parquet'\n",
"\n",
"rf_wl_ds = create_dataset('rf_reg', weak_labelled[rf_key], \n",
" weak_labelled[rf_key]['embedding_vec'], \n",
" weak_labelled[rf_key]['label'])\n",
"rf_wl_ds = create_dataset('rf_reg', weak_labelled[rf_key],\n",
" weak_labelled[rf_key]['embedding_vec'],\n",
" weak_labelled[rf_key]['label'])\n",
"\n",
"px_session = launch_px(rf_wl_ds, None)\n",
"px_session.view()"
Expand Down Expand Up @@ -294,8 +297,8 @@
"source": [
"svm_key = 'svm_weak_labeling_weaklabels.parquet'\n",
"\n",
"svm_wl_ds = create_dataset('svm_reg', weak_labelled[svm_key], \n",
" weak_labelled[svm_key]['embedding_vec'], \n",
"svm_wl_ds = create_dataset('svm_reg', weak_labelled[svm_key],\n",
" weak_labelled[svm_key]['embedding_vec'],\n",
" weak_labelled[svm_key]['label'])\n",
"\n",
"px_session = launch_px(svm_wl_ds, None)\n",
Expand Down

0 comments on commit 5242fa6

Please sign in to comment.