From c443417686471c2395e8fd08a2d03f06b1c2d3a8 Mon Sep 17 00:00:00 2001 From: toby petty Date: Mon, 18 Jul 2022 22:21:25 -0400 Subject: [PATCH] Adding notebook for transfer learning piece. --- Amazon_product_data_transfer_learning.ipynb | 1877 +++++++++++++++++++ 1 file changed, 1877 insertions(+) create mode 100644 Amazon_product_data_transfer_learning.ipynb diff --git a/Amazon_product_data_transfer_learning.ipynb b/Amazon_product_data_transfer_learning.ipynb new file mode 100644 index 0000000..5fb44e4 --- /dev/null +++ b/Amazon_product_data_transfer_learning.ipynb @@ -0,0 +1,1877 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "view-in-github", + "colab_type": "text" + }, + "source": [ + "\"Open" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "YXAL6gpkijkz", + "colab": { + "base_uri": "https://localhost:8080/" + }, + "outputId": "3ebd7658-10b0-4ba8-d5c4-99fb95044f4b" + }, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u001b[K |████████████████████████████████| 4.4 MB 5.1 MB/s \n", + "\u001b[K |████████████████████████████████| 6.6 MB 69.6 MB/s \n", + "\u001b[K |████████████████████████████████| 596 kB 78.2 MB/s \n", + "\u001b[K |████████████████████████████████| 101 kB 10.5 MB/s \n", + "\u001b[?25h" + ] + } + ], + "source": [ + "!pip install transformers --quiet" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Kd0xo5RHVbg-" + }, + "outputs": [], + "source": [ + "import pandas as pd\n", + "from sklearn.metrics import classification_report\n", + "from sklearn.model_selection import train_test_split\n", + "import tensorflow as tf\n", + "from tensorflow import keras\n", + "from transformers import BertTokenizer, TFBertForSequenceClassification " + ] + }, + { + "cell_type": "code", + "source": [ + "# Code for helping save models to GDrive after training:\n", + "\n", + "import datetime\n", + "import os\n", + "\n", + "from google.colab import drive\n", + "\n", + "# Mount Google Drive:\n", + "drive.mount(\"/content/gdrive\")\n", + "\n", + "# Directory where models will be stored in GDrive:\n", + "MODEL_DIR = \"/content/gdrive/MyDrive/models\"\n", + "\n", + "# Make the directories for storing results if they don't exist yet:\n", + "if not os.path.exists(MODEL_DIR):\n", + " os.mkdir(MODEL_DIR)\n", + "\n", + "\n", + "def gdrive_save_dir(*subdir: str, model_name: str = \"test_model\"): \n", + " \"\"\"Create timestamped directory in GDrive for storing checkpoints or models.\n", + " \n", + " Args:\n", + " subdir: optional subdirectories of the main model directory\n", + " (e.g. `checkpoints`, `final_model`, etc.)\n", + " model_name: main name for directory specifying the model being saved.\n", + " \"\"\"\n", + " model_dir = f\"{MODEL_DIR}/{model_name}\"\n", + " if not os.path.exists(model_dir):\n", + " os.mkdir(model_dir)\n", + " for s in subdir:\n", + " model_dir = f\"{model_dir}/{s}\"\n", + " if not os.path.exists(model_dir):\n", + " os.mkdir(model_dir)\n", + " now = datetime.datetime.now()\n", + " now_str = now.strftime(\"%Y_%m_%d__%H_%M_%S\")\n", + " dir_path = f\"{model_dir}/{now_str}\"\n", + " os.mkdir(dir_path)\n", + " print(f\"Created checkpoint dir: {dir_path}\")\n", + " return dir_path\n", + "\n", + "\n", + "gdrive_save_dir(\"checkpoints\", model_name = \"test_model\")" + ], + "metadata": { + "id": "Ms69Utx-0kHS", + "outputId": "4e1d4e86-e6c1-49dd-c03f-46ac7bf770fa", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 70 + } + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Mounted at /content/gdrive\n", + "Created checkpoint dir: /content/gdrive/MyDrive/models/test_model/checkpoints/2022_07_17__20_21_56\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "'/content/gdrive/MyDrive/models/test_model/checkpoints/2022_07_17__20_21_56'" + ], + "application/vnd.google.colaboratory.intrinsic+json": { + "type": "string" + } + }, + "metadata": {}, + "execution_count": 3 + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "JVa_1CPEIMv0" + }, + "source": [ + "## Create train and test data" + ] + }, + { + "cell_type": "code", + "source": [ + "# Using the datasets created in a separate notebook and saved to Github:\n", + "train_url = \"https://raw.githubusercontent.com/toby-p/w266-final-project/main/data/amazon/train.csv\"\n", + "test_url = \"https://raw.githubusercontent.com/toby-p/w266-final-project/main/data/amazon/test.csv\"\n", + "val_url = \"https://raw.githubusercontent.com/toby-p/w266-final-project/main/data/amazon/val.csv\"\n", + "\n", + "amazon_train = pd.read_csv(train_url, encoding=\"latin1\")\n", + "amazon_test = pd.read_csv(test_url, encoding=\"latin1\")\n", + "amazon_val = pd.read_csv(val_url, encoding=\"latin1\")" + ], + "metadata": { + "id": "DDJUbITGk83S" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "amazon_train.tail()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 206 + }, + "id": "-IDQWfD-oet_", + "outputId": "0c639ca9-acbf-4188-8a57-3aa906406465" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + " id sentence label\n", + "7995 89260 Easy access off the 101 lots of parking in the... 0\n", + "7996 62116 Meh. I went in for some accessories and a part... 0\n", + "7997 11115 Worst customer service ever. I called the stor... 0\n", + "7998 11885 I had my Canon Rebel T1i repaired after I drop... 0\n", + "7999 53295 Great store a little short on boys youth sizes... 0" + ], + "text/html": [ + "\n", + "
\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idsentencelabel
799589260Easy access off the 101 lots of parking in the...0
799662116Meh. I went in for some accessories and a part...0
799711115Worst customer service ever. I called the stor...0
799811885I had my Canon Rebel T1i repaired after I drop...0
799953295Great store a little short on boys youth sizes...0
\n", + "
\n", + " \n", + " \n", + " \n", + "\n", + " \n", + "
\n", + "
\n", + " " + ] + }, + "metadata": {}, + "execution_count": 5 + } + ] + }, + { + "cell_type": "code", + "source": [ + "amazon_test.tail()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 206 + }, + "id": "UbGZScrEof99", + "outputId": "ecd44fe7-805c-403d-bf9b-242995a2ab1f" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "execute_result", + "data": { + "text/plain": [ + " id sentence label\n", + "1995 63354 Big sale this week. All sort of little gadets ... 0\n", + "1996 45423 The new owner and management are great. I didn... 0\n", + "1997 12024 Came here to check out their Patio Furniture. ... 0\n", + "1998 89218 I brought in a flash drive with a 3-page docum... 0\n", + "1999 45672 Super helpful. Taught me exactly how to gel st... 0" + ], + "text/html": [ + "\n", + "
\n", + "
\n", + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idsentencelabel
199563354Big sale this week. All sort of little gadets ...0
199645423The new owner and management are great. I didn...0
199712024Came here to check out their Patio Furniture. ...0
199889218I brought in a flash drive with a 3-page docum...0
199945672Super helpful. Taught me exactly how to gel st...0
\n", + "
\n", + " \n", + " \n", + " \n", + "\n", + " \n", + "
\n", + "
\n", + " " + ] + }, + "metadata": {}, + "execution_count": 6 + } + ] + }, + { + "cell_type": "code", + "source": [ + "amazon_val.tail()" + ], + "metadata": { + "id": "37a7puk45rdg" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "x_train_full = amazon_train[\"reviewText\"]\n", + "y_train_full = amazon_train[\"label\"]\n", + "x_val = amazon_val[\"reviewText\"]\n", + "y_val = amazon_val[\"label\"]\n", + "x_test = amazon_test[\"reviewText\"]\n", + "y_test = amazon_test[\"label\"]" + ], + "metadata": { + "id": "vvzKO4Wj4f6c" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "print(f\"Shape x_train: {x_train.shape}\")\n", + "print(f\"Shape x_val: {x_val.shape}\")\n", + "print(f\"Shape x_test: {x_test.shape}\")\n", + "print(f\"Shape y_train: {y_train.shape}\")\n", + "print(f\"Shape y_val: {y_val.shape}\")\n", + "print(f\"Shape y_test: {y_test.shape}\")" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "GpZ1130U4yLK", + "outputId": "03d8afd0-4b78-4793-e18c-73d4fc76d27b" + }, + "execution_count": null, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Shape X_train: (7200,)\n", + "Shape X_valid: (800,)\n", + "Shape y_train: (7200,)\n", + "Shape y_val: (800,)\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "## Tokenize inputs" + ], + "metadata": { + "id": "HcyU1oVWakMM" + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "KYbFcFevWm_2", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 113, + "referenced_widgets": [ + "f11fa95cfbf944e1802041360bd570c4", + "8ad6f5ce20b54b1998e7564b75110c2c", + "915c00323a294011936f9df042b23a3c", + "192c6ad0b1f841d59404af12315dcbe0", + "3dcc93af7a9942ed90bcf271fcfc9733", + "85236db36963455d95b6064d357827a0", + "4970a1c799b5404a95dbdaf21381321c", + "7c1b3c45dbc64574a8d58857b377c313", + "3bc7a7d4fb76484db3c3a8c18560ae59", + "ec0b04c10bc4422e89eda112acf5ce94", + "9a58e1b736bb4061b0487aabd4787e00", + "3b718d0daf134614b0a5acbdb11a36d7", + "62d829b69bde49e99c18475c37a05653", + "86f761f7538d47d8a1511dcc6f42d0f6", + "bab2cbe5504943b495cc0b08575554eb", + "4034c1afab21422da245d768659bad36", + "5b63cc97705d4f24909deed2cf1d9fe7", + "996ebf10aa824f4a80105f6cd3ff6a45", + "34859c47432e47568c30a0a4bfa07009", + "e28e1c04929a495e8d0d9d04c434a26f", + "1b752a32ae58425e96e36d51ac6480f4", + "d9a7571d52f540b682250a9a0be70828", + "eba887bffb0d4d9b842e5159dfac955f", + "1e302ee6f1754b65a33243a32d18949d", + "925454fc09394de5a076c4aae58964d5", + "3d688a7ceeea4313b942ec99c4b56b16", + "0481df8d4eec4acf8ce79bdcef692d65", + "ddec1641b47e4f849e8962df58709c9f", + "17d6d1d0d479437c9b0658e3e586b68d", + "dc61574a6ac14fbf9660c3abf4fd4956", + "66cf5b29c3654b8ab290a0faf1e2d1f3", + "425ba70930c74f9ab4fe8a709aed6376", + "8f2b1dd0f4b44cdba6ee1306be4ed3f5" + ] + }, + "outputId": "a1dc2c7e-24e3-4ec4-98ec-2c08e2533a9a" + }, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "Downloading: 0%| | 0.00/226k [00:00