Skip to content
This repository has been archived by the owner on Jun 28, 2024. It is now read-only.

Commit

Permalink
Merge remote-tracking branch 'origin/main'
Browse files Browse the repository at this point in the history
  • Loading branch information
noaahh committed Jun 7, 2024
2 parents 58f2528 + 8e433a6 commit 29bb03f
Show file tree
Hide file tree
Showing 20 changed files with 130 additions and 33 deletions.
2 changes: 1 addition & 1 deletion models/eval/eval_results.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"eval_loss": 0.6932736039161682, "eval_accuracy": 0.5, "eval_f1_macro": 0.4620003056568347, "eval_f1_weighted": 0.46200030565683464, "eval_precision": 0.5, "eval_recall": 0.5, "eval_auroc": 0.5, "eval_runtime": 2.6124, "eval_samples_per_second": 254.941, "eval_steps_per_second": 8.039}
{"eval_loss": 0.6934633851051331, "eval_accuracy": 0.48375451263537905, "eval_f1_macro": 0.4440795475278234, "eval_f1_weighted": 0.4435433993506942, "eval_precision": 0.478506455399061, "eval_recall": 0.48472526326764676, "eval_auroc": 0.48472526326764676, "eval_runtime": 1.3678, "eval_samples_per_second": 202.514, "eval_steps_per_second": 6.58}
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file removed models/semi-supervised/finetune_nested/eval_loss.png
Binary file not shown.
2 changes: 1 addition & 1 deletion models/semi-supervised/finetune_nested/eval_results.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"0.25": {"eval_loss": 0.4045407474040985, "eval_accuracy": 0.8948948948948949, "eval_f1_macro": 0.8948863636363636, "eval_f1_weighted": 0.8948863636363638, "eval_precision": 0.8950231387513192, "eval_recall": 0.8948948948948949, "eval_auroc": 0.8948948948948949, "eval_runtime": 2.3628, "eval_samples_per_second": 281.865, "eval_steps_per_second": 2.963, "epoch": 25.0}, "0.5": {"eval_loss": 0.44522979855537415, "eval_accuracy": 0.8933933933933934, "eval_f1_macro": 0.8933931530475364, "eval_f1_weighted": 0.8933931530475363, "eval_precision": 0.8933969410576438, "eval_recall": 0.8933933933933934, "eval_auroc": 0.8933933933933934, "eval_runtime": 2.3526, "eval_samples_per_second": 283.09, "eval_steps_per_second": 2.975, "epoch": 25.0}, "0.75": {"eval_loss": 0.5143420696258545, "eval_accuracy": 0.8828828828828829, "eval_f1_macro": 0.8828733766233766, "eval_f1_weighted": 0.8828733766233766, "eval_precision": 0.8830072257854997, "eval_recall": 0.8828828828828829, "eval_auroc": 0.8828828828828829, "eval_runtime": 2.353, "eval_samples_per_second": 283.043, "eval_steps_per_second": 2.975, "epoch": 25.0}, "1.0": {"eval_loss": 0.5319307446479797, "eval_accuracy": 0.8828828828828829, "eval_f1_macro": 0.8828818267080297, "eval_f1_weighted": 0.8828818267080298, "eval_precision": 0.8828966947738648, "eval_recall": 0.882882882882883, "eval_auroc": 0.8828828828828829, "eval_runtime": 2.3732, "eval_samples_per_second": 280.639, "eval_steps_per_second": 2.95, "epoch": 25.0}}
{"0.25": {"eval_loss": 0.32715415954589844, "eval_accuracy": 0.855595667870036, "eval_f1_macro": 0.8555486024196912, "eval_f1_weighted": 0.8555391893296223, "eval_precision": 0.8562291144527987, "eval_recall": 0.855671984151809, "eval_auroc": 0.855671984151809, "eval_runtime": 0.9815, "eval_samples_per_second": 282.211, "eval_steps_per_second": 3.056, "epoch": 25.0}, "0.5": {"eval_loss": 0.3914715349674225, "eval_accuracy": 0.8592057761732852, "eval_f1_macro": 0.8590220412637513, "eval_f1_weighted": 0.8590404147547047, "eval_precision": 0.860779384035198, "eval_recall": 0.8590866437284954, "eval_auroc": 0.8590866437284955, "eval_runtime": 0.9827, "eval_samples_per_second": 281.872, "eval_steps_per_second": 3.053, "epoch": 25.0}, "0.75": {"eval_loss": 0.5000463128089905, "eval_accuracy": 0.8411552346570397, "eval_f1_macro": 0.840987370838117, "eval_f1_weighted": 0.8410060223735529, "eval_precision": 0.8423338566195709, "eval_recall": 0.8410489000104264, "eval_auroc": 0.8410489000104265, "eval_runtime": 0.9822, "eval_samples_per_second": 282.008, "eval_steps_per_second": 3.054, "epoch": 25.0}, "1.0": {"eval_loss": 0.49611490964889526, "eval_accuracy": 0.8772563176895307, "eval_f1_macro": 0.877062447786132, "eval_f1_weighted": 0.8770800723228045, "eval_precision": 0.8792994966442953, "eval_recall": 0.8771243874465644, "eval_auroc": 0.8771243874465645, "eval_runtime": 0.9822, "eval_samples_per_second": 282.031, "eval_steps_per_second": 3.054, "epoch": 25.0}}
Binary file not shown.
Binary file not shown.
Binary file not shown.
2 changes: 1 addition & 1 deletion models/supervised/finetune_nested/eval_results.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"0.25": {"eval_loss": 0.691838800907135, "eval_accuracy": 0.5675675675675675, "eval_f1_macro": 0.5675675675675675, "eval_f1_weighted": 0.5675675675675675, "eval_precision": 0.5675675675675675, "eval_recall": 0.5675675675675675, "eval_auroc": 0.5675675675675675, "eval_runtime": 2.3948, "eval_samples_per_second": 278.108, "eval_steps_per_second": 2.923, "epoch": 25.0}, "0.5": {"eval_loss": 0.6798407435417175, "eval_accuracy": 0.6816816816816816, "eval_f1_macro": 0.6765360824742268, "eval_f1_weighted": 0.6765360824742268, "eval_precision": 0.6940279102019589, "eval_recall": 0.6816816816816818, "eval_auroc": 0.6816816816816816, "eval_runtime": 2.3669, "eval_samples_per_second": 281.381, "eval_steps_per_second": 2.957, "epoch": 25.0}, "0.75": {"eval_loss": 0.3387678563594818, "eval_accuracy": 0.8843843843843844, "eval_f1_macro": 0.8843778676124103, "eval_f1_weighted": 0.8843778676124106, "eval_precision": 0.8844710636455477, "eval_recall": 0.8843843843843844, "eval_auroc": 0.8843843843843844, "eval_runtime": 2.3506, "eval_samples_per_second": 283.327, "eval_steps_per_second": 2.978, "epoch": 25.0}, "1.0": {"eval_loss": 0.2984059453010559, "eval_accuracy": 0.8993993993993994, "eval_f1_macro": 0.8992992789682137, "eval_f1_weighted": 0.8992992789682136, "eval_precision": 0.9009941329856584, "eval_recall": 0.8993993993993994, "eval_auroc": 0.8993993993993994, "eval_runtime": 2.3504, "eval_samples_per_second": 283.362, "eval_steps_per_second": 2.978, "epoch": 25.0}}
{"0.25": {"eval_loss": 0.6932923793792725, "eval_accuracy": 0.49097472924187724, "eval_f1_macro": 0.4247271441827582, "eval_f1_weighted": 0.42402238263957603, "eval_precision": 0.48541747951619196, "eval_recall": 0.49220623501199046, "eval_auroc": 0.4922062350119904, "eval_runtime": 0.9797, "eval_samples_per_second": 282.736, "eval_steps_per_second": 3.062, "epoch": 25.0}, "0.5": {"eval_loss": 0.691850483417511, "eval_accuracy": 0.5451263537906137, "eval_f1_macro": 0.5025937749401437, "eval_f1_weighted": 0.5031188685061988, "eval_precision": 0.5665643205794363, "eval_recall": 0.5440777812532582, "eval_auroc": 0.5440777812532582, "eval_runtime": 0.9804, "eval_samples_per_second": 282.536, "eval_steps_per_second": 3.06, "epoch": 25.0}, "0.75": {"eval_loss": 0.6915454268455505, "eval_accuracy": 0.5451263537906137, "eval_f1_macro": 0.48733842538190364, "eval_f1_weighted": 0.4879598009561908, "eval_precision": 0.5792866553736119, "eval_recall": 0.5439213846314253, "eval_auroc": 0.5439213846314253, "eval_runtime": 0.9814, "eval_samples_per_second": 282.253, "eval_steps_per_second": 3.057, "epoch": 25.0}, "1.0": {"eval_loss": 0.6900911927223206, "eval_accuracy": 0.5992779783393501, "eval_f1_macro": 0.5907439204568142, "eval_f1_weighted": 0.5909572719038776, "eval_precision": 0.6075076608784473, "eval_recall": 0.5987644666875195, "eval_auroc": 0.5987644666875196, "eval_runtime": 0.982, "eval_samples_per_second": 282.09, "eval_steps_per_second": 3.055, "epoch": 25.0}}
Binary file removed models/supervised/transfer_nested/eval_accuracy.png
Binary file not shown.
Binary file removed models/supervised/transfer_nested/eval_f1_macro.png
Binary file not shown.
Binary file not shown.
Binary file removed models/supervised/transfer_nested/eval_loss.png
Binary file not shown.
2 changes: 1 addition & 1 deletion models/supervised/transfer_nested/eval_results.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"0.25": {"eval_loss": 0.6932699680328369, "eval_accuracy": 0.496996996996997, "eval_f1_macro": 0.4623790648080097, "eval_f1_weighted": 0.4623790648080097, "eval_precision": 0.49595520357594014, "eval_recall": 0.49699699699699695, "eval_auroc": 0.49699699699699695, "eval_runtime": 2.3519, "eval_samples_per_second": 283.169, "eval_steps_per_second": 2.976, "epoch": 25.0}, "0.5": {"eval_loss": 0.6928256154060364, "eval_accuracy": 0.524024024024024, "eval_f1_macro": 0.48958361035425635, "eval_f1_weighted": 0.48958361035425635, "eval_precision": 0.532905138339921, "eval_recall": 0.524024024024024, "eval_auroc": 0.524024024024024, "eval_runtime": 2.3686, "eval_samples_per_second": 281.18, "eval_steps_per_second": 2.955, "epoch": 25.0}, "0.75": {"eval_loss": 0.6927742958068848, "eval_accuracy": 0.521021021021021, "eval_f1_macro": 0.4888841308066313, "eval_f1_weighted": 0.4888841308066313, "eval_precision": 0.5280843373493975, "eval_recall": 0.5210210210210211, "eval_auroc": 0.521021021021021, "eval_runtime": 2.35, "eval_samples_per_second": 283.403, "eval_steps_per_second": 2.979, "epoch": 25.0}, "1.0": {"eval_loss": 0.6926871538162231, "eval_accuracy": 0.5195195195195195, "eval_f1_macro": 0.48930296756383707, "eval_f1_weighted": 0.48930296756383707, "eval_precision": 0.5255715045188729, "eval_recall": 0.5195195195195195, "eval_auroc": 0.5195195195195196, "eval_runtime": 2.3534, "eval_samples_per_second": 282.994, "eval_steps_per_second": 2.974, "epoch": 25.0}}
{"0.25": {"eval_loss": 0.6934651136398315, "eval_accuracy": 0.48375451263537905, "eval_f1_macro": 0.4440795475278234, "eval_f1_weighted": 0.4435433993506942, "eval_precision": 0.478506455399061, "eval_recall": 0.48472526326764676, "eval_auroc": 0.48472526326764676, "eval_runtime": 0.9834, "eval_samples_per_second": 281.682, "eval_steps_per_second": 3.051, "epoch": 25.0}, "0.5": {"eval_loss": 0.6924956440925598, "eval_accuracy": 0.5342960288808665, "eval_f1_macro": 0.4941892561398542, "eval_f1_weighted": 0.4947034455339697, "eval_precision": 0.5484593199757134, "eval_recall": 0.5332864143467835, "eval_auroc": 0.5332864143467835, "eval_runtime": 0.9856, "eval_samples_per_second": 281.049, "eval_steps_per_second": 3.044, "epoch": 25.0}, "0.75": {"eval_loss": 0.6925008893013, "eval_accuracy": 0.5342960288808665, "eval_f1_macro": 0.4941892561398542, "eval_f1_weighted": 0.4947034455339697, "eval_precision": 0.5484593199757134, "eval_recall": 0.5332864143467835, "eval_auroc": 0.5332864143467835, "eval_runtime": 0.9798, "eval_samples_per_second": 282.72, "eval_steps_per_second": 3.062, "epoch": 25.0}, "1.0": {"eval_loss": 0.6925026774406433, "eval_accuracy": 0.5342960288808665, "eval_f1_macro": 0.4941892561398542, "eval_f1_weighted": 0.4947034455339697, "eval_precision": 0.5484593199757134, "eval_recall": 0.5332864143467835, "eval_auroc": 0.5332864143467835, "eval_runtime": 0.9818, "eval_samples_per_second": 282.129, "eval_steps_per_second": 3.056, "epoch": 25.0}}
Binary file removed models/supervised/transfer_nested/eval_runtime.png
Binary file not shown.
Binary file not shown.
Binary file not shown.
122 changes: 98 additions & 24 deletions notebooks/main.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -355,6 +355,7 @@
"metadata": {},
"cell_type": "code",
"source": [
"from sklearn.metrics import roc_curve\n",
"import json\n",
"import matplotlib.pyplot as plt\n",
"\n",
Expand All @@ -364,7 +365,7 @@
"\n",
"MODEL_DIR = os.getenv(\"MODELS_DIR\")\n",
"\n",
"def plot_model_performance(results_data, model_names, baseline_data=None, metrics=None):\n",
"def plot_model_performance(results_data, model_names, baseline_data=None, metrics=None, baseline_name='Baseline'):\n",
" if metrics is None:\n",
" metrics = ['eval_accuracy', 'eval_f1_macro', 'eval_f1_weighted']\n",
"\n",
Expand All @@ -378,7 +379,7 @@
"\n",
" if baseline_data:\n",
" baseline_accuracy = baseline_data['eval_accuracy']\n",
" ax.axhline(y=baseline_accuracy, color='r', linestyle='--', label='Baseline')\n",
" ax.axhline(y=baseline_accuracy, color='r', linestyle='--', label=baseline_name)\n",
"\n",
" for i, model_results in enumerate(results):\n",
" values = [model_results[fraction]['eval_accuracy'] for fraction in fractions]\n",
Expand All @@ -389,6 +390,36 @@
" ax.set_xlabel('Fraction of Labeled Samples')\n",
" ax.set_ylabel('Accuracy')\n",
" ax.set_title('Model Accuracy Comparison')\n",
" ax.set_ylim(0, 1)\n",
" ax.legend(loc='lower right')\n",
" ax.grid(True)\n",
"\n",
" plt.tight_layout()\n",
" plt.show()\n",
"\n",
"def plot_model_auroc(results_data, model_names, baseline_data=None, baseline_name='Baseline'):\n",
" results = results_data\n",
"\n",
" fractions = sorted(results[0].keys(), key=float)\n",
" num_fractions = len(fractions)\n",
"\n",
" fig, ax = plt.subplots(figsize=(10, 6))\n",
" x = range(num_fractions)\n",
"\n",
" if baseline_data:\n",
" baseline_auroc = baseline_data['eval_auroc']\n",
" ax.axhline(y=baseline_auroc, color='r', linestyle='--', label=baseline_name)\n",
"\n",
" for i, model_results in enumerate(results):\n",
" values = [model_results[fraction]['eval_auroc'] for fraction in fractions]\n",
" ax.plot(x, values, marker='o', label=model_names[i])\n",
"\n",
" ax.set_xticks(x)\n",
" ax.set_xticklabels(fractions)\n",
" ax.set_xlabel('Fraction of Labeled Samples')\n",
" ax.set_ylabel('AUROC')\n",
" ax.set_title('Model AUROC Comparison')\n",
" ax.set_ylim(0, 1)\n",
" ax.legend(loc='lower right')\n",
" ax.grid(True)\n",
"\n",
Expand Down Expand Up @@ -461,7 +492,7 @@
"\n",
"Before we dive deeper into the chosen weak labelling technique and its impact on the model performance, we will first decide whether we will train our model via transfer learning or fine-tuning.\n",
"\n",
"For this we will train the model using the nested splits. "
"For this we will train the model using the nested splits on both techniques and compare the results. The results are stored in the `data/eval` directory as `.json` files."
]
},
{
Expand All @@ -480,67 +511,108 @@
"with open(f'{MODEL_DIR}/supervised/transfer_nested/eval_results.json') as file:\n",
" transfer_nested_data = json.load(file)\n",
"\n",
"\n",
"plot_model_performance([transfer_nested_data], ['Transfer Learning'], baseline_data, metrics=relevant_metrics)\n",
"\n"
],
"id": "3bd5576049fc6c83",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "markdown",
"source": "The results show that the transfer learning model barely outperforms the baseline model. This indicates that the pretrained model's knowledge is not sufficient to achieve high performance on the sentiment analysis task. ",
"id": "5d0cb8c4c5b514f2"
},
{
"cell_type": "markdown",
"id": "23d7a8e1",
"metadata": {},
"source": "### 5.2 Fine-tuning"
"source": [
"### 5.2 Fine-tuning\n",
"To identify if fine-tuning the model can improve the performance, we will train the model using the nested splits and compare the results. \n",
"\n",
"For the fine-tuning we are using the recommended hyperparameters from the Hugging Face documentation. "
]
},
{
"metadata": {},
"cell_type": "code",
"source": [
"with open(f'{MODEL_DIR}/supervised/finetune_nested/eval_results.json') as file:\n",
" finetune_nested_data = json.load(file)\n"
" finetune_nested_data = json.load(file)\n",
"\n",
"plot_model_performance([finetune_nested_data], ['Fine-tuning'], baseline_data, metrics=relevant_metrics)\n"
],
"id": "1871656e5dc97ef9",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "markdown",
"source": "The results show that fine-tuning outperforms both the baseline model and the transfer learning model. We can also see that after using 75% of the labelled data (750 labelled samples) the model stagnates in its performance. This indicates that the model has reached its capacity to learn from the data and adding more data does not substantially improve the performance.\n",
"id": "d760fdb577139928"
},
{
"cell_type": "markdown",
"id": "22cfcf6e",
"metadata": {},
"source": [
"## 6. Semi-Supervised Learning Performance\n",
"Semi-supervised learning techniques, specifically K-Nearest Neighbors (KNN) and Logistic Regression (LogReg), are employed to generate weak labels for the unlabeled samples. The impact of the number of labeled samples and weak labeling strategies on model performance is analyzed and presented."
"After we established that fine-tuning is the best approach for training the model, we will now evaluate the performance of the semi-supervised learning techniques. We will compare the performance of the fine-tuned model with weak labels generated using different weak labelling strategies. \n",
"\n",
"The nested split logic above is used, with the small difference that each split contains the fully labeled data. This means that the nested split is applied to the weak labels and then concatenated with the fully labeled data. "
]
},
{
"cell_type": "markdown",
"id": "1c53e734",
"metadata": {},
"source": "### 6.1 Logistic Regression (LogReg) Weak Labelling"
},
{
"metadata": {},
"cell_type": "code",
"source": [
"### 6.1 Logistic Regression (LogReg)\n",
"The performance of LogReg-based weak labeling using different amounts of labeled samples is presented and discussed."
]
"with open(f'{MODEL_DIR}/semi-supervised/finetune_nested/eval_results.json') as file:\n",
" logreg_nested_data = json.load(file)\n",
" \n",
"plot_model_performance([logreg_nested_data], ['LogReg Weak-Labelling'], finetune_nested_data[\"1.0\"], metrics=relevant_metrics, baseline_name='Fine-tuning 100% (Fully Labeled)')\n"
],
"id": "6d5220fe8adfcba9",
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Adding weak labels to the dataset has a significant impact on the model performance. With only an addition 25% ",
"id": "79dc94ed9017120e"
},
{
"cell_type": "markdown",
"id": "e3d33785",
"metadata": {},
"source": [
"## 7. Learning Curve Analysis\n",
"The learning curve, plotting the model performance against varying numbers of labeled samples for each technique (supervised and semi-supervised), is presented and analyzed. The focus is on the range with few labeled samples, and the practical implications of the results are discussed."
]
"source": "## 7. Learning Curve Analysis"
},
{
"cell_type": "code",
"id": "eabcf731",
"metadata": {},
"source": [
"# Code for generating the learning curve plot"
"# Plot all results\n",
"plot_model_performance([transfer_nested_data, finetune_nested_data, logreg_nested_data], ['Transfer Learning', 'Fine-tuning', 'LogReg Weak-Labelling'], baseline_data, metrics=relevant_metrics)"
],
"outputs": [],
"execution_count": null
},
{
"metadata": {},
"cell_type": "markdown",
"source": "Using the ",
"id": "ea3d85c7caa6c502"
},
{
"cell_type": "markdown",
"id": "7ce2752d",
Expand All @@ -550,6 +622,17 @@
"A thorough analysis of the results is conducted, comparing the baseline model, supervised learning techniques, and semi-supervised learning techniques. The impact of different weak labeling strategies and training data sizes on model performance is evaluated. The best approach for the chosen dataset is determined, emphasizing the models that achieve acceptable performance with few manually annotated samples."
]
},
{
"metadata": {},
"cell_type": "code",
"source": [
"# Plot all results\n",
"plot_model_auroc([transfer_nested_data, finetune_nested_data, logreg_nested_data], ['Transfer Learning', 'Fine-tuning', 'LogReg Weak-Labelling'], baseline_data)"
],
"id": "76592bb8196721ae",
"outputs": [],
"execution_count": null
},
{
"cell_type": "markdown",
"id": "45316041",
Expand All @@ -567,15 +650,6 @@
"## 10. Conclusion and Future Directions\n",
"The key findings, insights, and potential implications of the sentiment analysis mini-challenge are summarized. The effectiveness of weak supervision techniques in reducing the need for manual annotation while maintaining acceptable model performance is discussed. Future directions for research and improvements are outlined."
]
},
{
"cell_type": "markdown",
"id": "3595d590",
"metadata": {},
"source": [
"## 11. AI Tool Usage Assessment\n",
"The use of ChatGPT or other AI tools throughout the mini-challenge is documented and assessed. The tasks for which they were used, the prompting strategies employed, and their contribution to solving the problem and acquiring new skills are specified."
]
}
],
"metadata": {
Expand Down
Loading

0 comments on commit 29bb03f

Please sign in to comment.