From 32800bcbd7f05ad8c9fa85fdeabede62c1a3c12e Mon Sep 17 00:00:00 2001 From: MohamedAliRashad Date: Sat, 22 May 2021 19:27:18 +0200 Subject: [PATCH 1/3] Add list_models method to know available models --- aitextgen/aitextgen.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/aitextgen/aitextgen.py b/aitextgen/aitextgen.py index 6f81636..c5287a2 100644 --- a/aitextgen/aitextgen.py +++ b/aitextgen/aitextgen.py @@ -72,6 +72,7 @@ class aitextgen: tokenizer = None vocab_file = os.path.join(STATIC_PATH, "gpt2_vocab.json") merges_file = os.path.join(STATIC_PATH, "gpt2_merges.txt") + models = ["124M", "355M", "774M", "1558M"] bos_token = "<|endoftext|>" eos_token = "<|endoftext|>" unk_token = "<|endoftext|>" @@ -123,12 +124,7 @@ def __init__( if not os.path.isfile( os.path.join(cache_dir, f"pytorch_model_{tf_gpt2}.bin") ): - assert tf_gpt2 in [ - "124M", - "355M", - "774M", - "1558M", - ], "Invalid TensorFlow GPT-2 model size." + assert tf_gpt2 in self.models, "Invalid TensorFlow GPT-2 model size." logger.info( f"Downloading the {tf_gpt2} GPT-2 TensorFlow weights/config " @@ -863,3 +859,6 @@ def __repr__(self) -> str: num_params_m = int(sum(p.numel() for p in self.model.parameters()) / 10 ** 6) model_name = type(self.model.config).__name__.replace("Config", "") return f"{model_name} loaded with {num_params_m}M parameters." 
+ + def list_models(self) -> None: + print("\n".join(self.models)) \ No newline at end of file From aff9acea0e3c6b8cad9372fb65e862801f322002 Mon Sep 17 00:00:00 2001 From: MohamedAliRashad Date: Fri, 28 May 2021 21:51:36 +0200 Subject: [PATCH 2/3] Add training script on colab --- train_aitextgen.ipynb | 967 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 967 insertions(+) create mode 100644 train_aitextgen.ipynb diff --git a/train_aitextgen.ipynb b/train_aitextgen.ipynb new file mode 100644 index 0000000..7230602 --- /dev/null +++ b/train_aitextgen.ipynb @@ -0,0 +1,967 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "train_aitextgen.ipynb", + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU", + "widgets": { + "application/vnd.jupyter.widget-state+json": { + "2130bcc715494effafebd6e9a0de9527": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "state": { + "_view_name": "HBoxView", + "_dom_classes": [], + "_model_name": "HBoxModel", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.5.0", + "box_style": "", + "layout": "IPY_MODEL_cd5f9253d1af45e79f7f0771617bafdd", + "_model_module": "@jupyter-widgets/controls", + "children": [ + "IPY_MODEL_75fba03efbb14c7ba437bfe54c7a17ef", + "IPY_MODEL_1a4f2172fdc54faebf11411ac5f96da8" + ] + } + }, + "cd5f9253d1af45e79f7f0771617bafdd": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": "row wrap", + "width": "100%", + "min_width": null, + "border": null, + 
"align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": "inline-flex", + "left": null + } + }, + "75fba03efbb14c7ba437bfe54c7a17ef": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "state": { + "_view_name": "ProgressView", + "style": "IPY_MODEL_8af2a63fd28d4fb399256768b7fa353f", + "_dom_classes": [], + "description": "100%", + "_model_name": "FloatProgressModel", + "bar_style": "success", + "max": 5403, + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": 5403, + "_view_count": null, + "_view_module_version": "1.5.0", + "orientation": "horizontal", + "min": 0, + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_8ae2138d73fd433ab3352ec9a67a6182" + } + }, + "1a4f2172fdc54faebf11411ac5f96da8": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "state": { + "_view_name": "HTMLView", + "style": "IPY_MODEL_de7a9b7144d14b58b0a490e5684b932f", + "_dom_classes": [], + "description": "", + "_model_name": "HTMLModel", + "placeholder": "​", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": " 5403/5403 [00:00<00:00, 19850.46it/s]", + "_view_count": null, + "_view_module_version": "1.5.0", + 
"description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_0543802a7e5e4de8906f9afb4f0c7ed7" + } + }, + "8af2a63fd28d4fb399256768b7fa353f": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "state": { + "_view_name": "StyleView", + "_model_name": "ProgressStyleModel", + "description_width": "initial", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "bar_color": null, + "_model_module": "@jupyter-widgets/controls" + } + }, + "8ae2138d73fd433ab3352ec9a67a6182": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": "2", + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "de7a9b7144d14b58b0a490e5684b932f": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "state": { + "_view_name": "StyleView", + "_model_name": 
"DescriptionStyleModel", + "description_width": "", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "_model_module": "@jupyter-widgets/controls" + } + }, + "0543802a7e5e4de8906f9afb4f0c7ed7": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "a875484ff8b04a028b0d9cdc454e47f6": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HBoxModel", + "state": { + "_view_name": "HBoxView", + "_dom_classes": [], + "_model_name": "HBoxModel", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.5.0", + "box_style": "", + "layout": "IPY_MODEL_ef81d88aa21c4df1900e9de96205336a", + "_model_module": "@jupyter-widgets/controls", + "children": [ + "IPY_MODEL_db2571e57bb949bb8e5f134b373e7701", + 
"IPY_MODEL_1d3f19e3d30e4f2484005f9bab7d9d7f" + ] + } + }, + "ef81d88aa21c4df1900e9de96205336a": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": "row wrap", + "width": "100%", + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": "inline-flex", + "left": null + } + }, + "db2571e57bb949bb8e5f134b373e7701": { + "model_module": "@jupyter-widgets/controls", + "model_name": "FloatProgressModel", + "state": { + "_view_name": "ProgressView", + "style": "IPY_MODEL_9e7e63b1e4c74f0fba28bf385ebd6607", + "_dom_classes": [], + "description": "Loss: 0.216 — Avg: 0.227 — GPU Mem: 11434 MB: 100%", + "_model_name": "FloatProgressModel", + "bar_style": "success", + "max": 1000, + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": 1000, + "_view_count": null, + "_view_module_version": "1.5.0", + "orientation": "horizontal", + "min": 0, + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": 
"IPY_MODEL_4595fc3cb883439db44a149302e2d47c" + } + }, + "1d3f19e3d30e4f2484005f9bab7d9d7f": { + "model_module": "@jupyter-widgets/controls", + "model_name": "HTMLModel", + "state": { + "_view_name": "HTMLView", + "style": "IPY_MODEL_5e28516408e0462fa4dc98fb47dc7d5a", + "_dom_classes": [], + "description": "", + "_model_name": "HTMLModel", + "placeholder": "​", + "_view_module": "@jupyter-widgets/controls", + "_model_module_version": "1.5.0", + "value": " 1000/1000 [20:53<00:00, 1.25s/it]", + "_view_count": null, + "_view_module_version": "1.5.0", + "description_tooltip": null, + "_model_module": "@jupyter-widgets/controls", + "layout": "IPY_MODEL_776c2c30c51f45909424a7e955a2ff52" + } + }, + "9e7e63b1e4c74f0fba28bf385ebd6607": { + "model_module": "@jupyter-widgets/controls", + "model_name": "ProgressStyleModel", + "state": { + "_view_name": "StyleView", + "_model_name": "ProgressStyleModel", + "description_width": "initial", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "bar_color": null, + "_model_module": "@jupyter-widgets/controls" + } + }, + "4595fc3cb883439db44a149302e2d47c": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": "2", + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + 
"align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null + } + }, + "5e28516408e0462fa4dc98fb47dc7d5a": { + "model_module": "@jupyter-widgets/controls", + "model_name": "DescriptionStyleModel", + "state": { + "_view_name": "StyleView", + "_model_name": "DescriptionStyleModel", + "description_width": "", + "_view_module": "@jupyter-widgets/base", + "_model_module_version": "1.5.0", + "_view_count": null, + "_view_module_version": "1.2.0", + "_model_module": "@jupyter-widgets/controls" + } + }, + "776c2c30c51f45909424a7e955a2ff52": { + "model_module": "@jupyter-widgets/base", + "model_name": "LayoutModel", + "state": { + "_view_name": "LayoutView", + "grid_template_rows": null, + "right": null, + "justify_content": null, + "_view_module": "@jupyter-widgets/base", + "overflow": null, + "_model_module_version": "1.2.0", + "_view_count": null, + "flex_flow": null, + "width": null, + "min_width": null, + "border": null, + "align_items": null, + "bottom": null, + "_model_module": "@jupyter-widgets/base", + "top": null, + "grid_column": null, + "overflow_y": null, + "overflow_x": null, + "grid_auto_flow": null, + "grid_area": null, + "grid_template_columns": null, + "flex": null, + "_model_name": "LayoutModel", + "justify_items": null, + "grid_row": null, + "max_height": null, + "align_content": null, + "visibility": null, + "align_self": null, + "height": null, + "min_height": null, + "padding": null, + "grid_auto_rows": null, + "grid_gap": null, + "max_width": null, + "order": null, + "_view_module_version": "1.2.0", + "grid_template_areas": null, + "object_position": null, + "object_fit": null, + "grid_auto_columns": null, + "margin": null, + "display": null, + "left": null 
+ } + } + } + } + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "gdA1n13fgYcC" + }, + "source": [ + "# Train AITextGen models with GPU/TPU for Free \n", + "\n", + "by [Mohamed Rashad](https://github.com/MohamedAliRashad)\n", + "\n", + "For more about `aitextgen`, you can visit [this GitHub repository](https://github.com/minimaxir/aitextgen).\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "O6CCL72mgi9R" + }, + "source": [ + "## Install Dependencies" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "ANQ6CAKcfr3z" + }, + "source": [ + "!pip3 install -q aitextgen\n", + "from aitextgen import aitextgen" + ], + "execution_count": 1, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "JYEa2TUWg3P3" + }, + "source": [ + "## Get shakespeare dataset (optional)" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "yiCnQLP5gS90", + "outputId": "e7f67b3a-7bdc-4246-a3f4-5ffd2127826c" + }, + "source": [ + "!git clone https://github.com/ravexina/shakespeare-plays-dataset-scraper.git\n", + "!mv /content/shakespeare-plays-dataset-scraper/shakespeare-db/ /content/db/" + ], + "execution_count": 2, + "outputs": [ + { + "output_type": "stream", + "text": [ + "fatal: destination path 'shakespeare-plays-dataset-scraper' already exists and is not an empty directory.\n", + "mv: cannot stat '/content/shakespeare-plays-dataset-scraper/shakespeare-db/': No such file or directory\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UFKYwnAThVf6" + }, + "source": [ + "## Upload dataset to train on\n", + "\n", + "The data needs to be text in one file." 
+ ] + }, + { + "cell_type": "code", + "metadata": { + "id": "CuxGPVZLhV2M" + }, + "source": [ + "from google.colab import files\n", + "uploaded = files.upload()" + ], + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DhRxd8hshBgM" + }, + "source": [ + "## Enter configurations for training" + ] + }, + { + "cell_type": "code", + "metadata": { + "cellView": "form", + "id": "h12v-6sCimJc" + }, + "source": [ + "training_file_path = \"/content/db/Hamlet.txt\" #@param {type:\"string\"}\n", + "gpt_model = '124M' #@param [\"124M\", \"355M\", \"774M\", \"1558M\"]" + ], + "execution_count": 3, + "outputs": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "XJiIjj9fO3Kx" + }, + "source": [ + "## Training Script" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000, + "referenced_widgets": [ + "2130bcc715494effafebd6e9a0de9527", + "cd5f9253d1af45e79f7f0771617bafdd", + "75fba03efbb14c7ba437bfe54c7a17ef", + "1a4f2172fdc54faebf11411ac5f96da8", + "8af2a63fd28d4fb399256768b7fa353f", + "8ae2138d73fd433ab3352ec9a67a6182", + "de7a9b7144d14b58b0a490e5684b932f", + "0543802a7e5e4de8906f9afb4f0c7ed7", + "a875484ff8b04a028b0d9cdc454e47f6", + "ef81d88aa21c4df1900e9de96205336a", + "db2571e57bb949bb8e5f134b373e7701", + "1d3f19e3d30e4f2484005f9bab7d9d7f", + "9e7e63b1e4c74f0fba28bf385ebd6607", + "4595fc3cb883439db44a149302e2d47c", + "5e28516408e0462fa4dc98fb47dc7d5a", + "776c2c30c51f45909424a7e955a2ff52" + ] + }, + "id": "kpjapNeihT4k", + "outputId": "df0022ed-e366-4ee7-c191-777afe3af3b1" + }, + "source": [ + "from aitextgen.TokenDataset import TokenDataset\n", + "from aitextgen.tokenizers import train_tokenizer\n", + "from aitextgen import aitextgen\n", + "\n", + "# Train a custom BPE Tokenizer on the downloaded text\n", + "# This will save one file: `aitextgen.tokenizer.json`, which contains the\n", + "# information needed to rebuild the tokenizer.\n", + 
"train_tokenizer(training_file_path)\n", + "tokenizer_file = \"aitextgen.tokenizer.json\"\n", + "\n", + "# Instantiate aitextgen using the created tokenizer and config\n", + "ai = aitextgen(tf_gpt2=gpt_model, tokenizer_file=tokenizer_file)\n", + "\n", + "# You can build datasets for training by creating TokenDatasets,\n", + "# which automatically processes the dataset with the appropriate size.\n", + "data = TokenDataset(training_file_path, tokenizer_file=tokenizer_file, block_size=64)\n", + "\n", + "# Train the model! It will save pytorch_model.bin periodically and after completion to the `trained_model` folder.\n", + "# On a 2020 8-core iMac, this took ~25 minutes to run.\n", + "ai.train(data, batch_size=64, num_steps=1000, generate_every=500, save_every=500)\n" + ], + "execution_count": 4, + "outputs": [ + { + "output_type": "display_data", + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "2130bcc715494effafebd6e9a0de9527", + "version_minor": 0, + "version_major": 2 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=5403.0), HTML(value='')), layout=Layout(d…" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "stream", + "text": [ + "GPU available: True, used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]\n" + ], + "name": "stderr" + }, + { + "output_type": "stream", + "text": [ + "\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a875484ff8b04a028b0d9cdc454e47f6", + "version_minor": 0, + "version_major": 2 + }, + "text/plain": [ + "HBox(children=(FloatProgress(value=0.0, layout=Layout(flex='2'), max=1000.0), HTML(value='')), layout=Layout(d…" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "stream", + "text": [ + "\u001b[1m500 steps reached: saving model to /trained_model\u001b[0m\n", + "\u001b[1m500 steps 
reached: generating sample texts.\u001b[0m\n", + "==========\n", + "mwood, wormwood.\n", + "Player Queen\n", + " Since fool no where but in's own house.\n", + " Enter LAERTES and O, forced!\n", + "Ghost\n", + " My hour is almost come, good mother.\n", + " Since my dear Give houch dear lord, sir,--\n", + "HAMLET\n", + "HORATIO\n", + " In my dear Ran's eyeild and mind's eye, tell us, Horatio,--\n", + " He sword, Horatio,--\n", + "HORATIO\n", + " I saw's wonder dear lord, som in the old theme.\n", + " Enter Keep; and all you, seen, see this came be wond, twent, with you see the same before let's goodman defectain!\n", + "HAMLET\n", + "HAMLET\n", + "HAMLET\n", + "HORATIO\n", + " Young For uses, thou art, sir.\n", + "HORATIO\n", + " I saws\n", + " I sawason!\n", + "HORATIO\n", + "HAMLET\n", + "HAMLET\n", + "HORATIO\n", + " Mostrew's a harposen's a ha!\n", + " Marry, sir, twit's goodman defe\n", + "==========\n", + "\u001b[1m1,000 steps reached: saving model to /trained_model\u001b[0m\n", + "\u001b[1m1,000 steps reached: generating sample texts.\u001b[0m\n", + "==========\n", + "ent piece of two brothers.\n", + " Seee, what a grace was seated on this brow;\n", + " Hyperion's curls; the front of Jove himself;\n", + " An eye like Mars, to threaten and command;\n", + " A station like the nonce, a wind away:\n", + " A flowed oreason have not cuffers\n", + " O roys in the nonce.\n", + " Hoonsty she capit in the nonstashy; and pe\n", + " I will their order thus to this retinuy!\n", + " Hy;\n", + " Hys!\n", + " Add, how LAERTES\n", + " Why is faded this to this most deed\n", + " We must be wind ans, HAMLET\n", + " Why such a none wed deeds\n", + " That is this will fit!\n", + " HORATIO\n", + " HAMLET\n", + " Unto, to this matter in my father!\n", + " That, HAMLET\n", + " That to the harated on\n", + "==========\n", + "\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "aMFJqQS2PBNj" + }, + "source": [ + "## Load weights and Generate 
text" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "xzhdXeUBNVow", + "outputId": "2123b4f9-31d9-46b0-bc90-6868e1f2726c" + }, + "source": [ + "# With your trained model, you can reload the model at any time by\n", + "# providing the folder containing the pytorch_model.bin model weights + the config, and providing the tokenizer.\n", + "ai2 = aitextgen(model_folder=\"trained_model\",\n", + " tokenizer_file=\"aitextgen.tokenizer.json\")\n", + "\n", + "ai2.generate(1, prompt=\"Hamlet\")" + ], + "execution_count": 7, + "outputs": [ + { + "output_type": "stream", + "text": [ + "\u001b[1mHamlet\u001b[0ms are horse, when he meant to beg it; might it not?\n", + "HORATIO\n", + " Ay, my lord.\n", + "HAMLET\n", + " Why, e'en so: and now my Lady Worm's; chapless, and\n", + " knocked about the mazzard with a sexton's spade:\n", + " here's fashion, ans ans it ans mere, answer here in the man\n", + " Let memory, I must wife' betwerewering you shall hangers\n", + " Not a ford.\n", + " Not a port\n", + " are another, and pile\n", + " are most in the fire of their addddddddesty would perughted gi would perce would perualents\n", + " and pers\n", + " Py, and their\n", + " and sil, their\n", + " Players\n", + " Players\n", + " their\n", + " and their to sch\n", + " and as fashion will fire\n", + " fit, their in the fire\n", + " fire\n" + ], + "name": "stdout" + } + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qfhuBbTMgY2S" + }, + "source": [ + "## Download trained weights" + ] + }, + { + "cell_type": "code", + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 68 + }, + "id": "VOVubaP9csQ8", + "outputId": "b1fcd4e5-849c-4ad5-ad23-88eb299d7091" + }, + "source": [ + "!zip -r /content/trained_weights.zip /content/trained_model\n", + "\n", + "from google.colab import files\n", + "files.download(\"/content/trained_weights.zip\")" + ], + "execution_count": 8, + "outputs": [ 
+ { + "output_type": "stream", + "text": [ + " adding: content/trained_model/ (stored 0%)\n", + " adding: content/trained_model/config.json (deflated 50%)\n", + " adding: content/trained_model/pytorch_model.bin (deflated 9%)\n" + ], + "name": "stdout" + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "\n", + " async function download(id, filename, size) {\n", + " if (!google.colab.kernel.accessAllowed) {\n", + " return;\n", + " }\n", + " const div = document.createElement('div');\n", + " const label = document.createElement('label');\n", + " label.textContent = `Downloading \"${filename}\": `;\n", + " div.appendChild(label);\n", + " const progress = document.createElement('progress');\n", + " progress.max = size;\n", + " div.appendChild(progress);\n", + " document.body.appendChild(div);\n", + "\n", + " const buffers = [];\n", + " let downloaded = 0;\n", + "\n", + " const channel = await google.colab.kernel.comms.open(id);\n", + " // Send a message to notify the kernel that we're ready.\n", + " channel.send({})\n", + "\n", + " for await (const message of channel.messages) {\n", + " // Send a message to notify the kernel that we're ready.\n", + " channel.send({})\n", + " if (message.buffers) {\n", + " for (const buffer of message.buffers) {\n", + " buffers.push(buffer);\n", + " downloaded += buffer.byteLength;\n", + " progress.value = downloaded;\n", + " }\n", + " }\n", + " }\n", + " const blob = new Blob(buffers, {type: 'application/binary'});\n", + " const a = document.createElement('a');\n", + " a.href = window.URL.createObjectURL(blob);\n", + " a.download = filename;\n", + " div.appendChild(a);\n", + " a.click();\n", + " div.remove();\n", + " }\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": { + "tags": [] + } + }, + { + "output_type": "display_data", + "data": { + "application/javascript": [ + "download(\"download_0a74b5db-afe9-49ab-8ebb-9f339de0e291\", \"trained_weights.zip\", 462485305)" + ], + "text/plain": [ + 
"" + ] + }, + "metadata": { + "tags": [] + } + } + ] + } + ] +} \ No newline at end of file From e03f4d5e0d59cf8fba3c9b4ff7fb8b6f2679acf6 Mon Sep 17 00:00:00 2001 From: MohamedAliRashad Date: Fri, 28 May 2021 21:53:17 +0200 Subject: [PATCH 3/3] Add Colab badge to README.md --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 7e644df..a7c6589 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,7 @@ # aitextgen +[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/MohamedAliRashad/aitextgen/blob/colab-tutorial/train_aitextgen.ipynb) + A robust Python tool for text-based AI training and generation using [OpenAI's](https://openai.com) [GPT-2](https://openai.com/blog/better-language-models/) and [EleutherAI's](https://www.eleuther.ai) [GPT Neo/GPT-3](https://github.com/EleutherAI/gpt-neo) architecture. aitextgen is a Python package that leverages [PyTorch](https://pytorch.org), [Hugging Face Transformers](https://github.com/huggingface/transformers) and [pytorch-lightning](https://github.com/PyTorchLightning/pytorch-lightning) with specific optimizations for text generation using GPT-2, plus _many_ added features. It is the successor to [textgenrnn](https://github.com/minimaxir/textgenrnn) and [gpt-2-simple](https://github.com/minimaxir/gpt-2-simple), taking the best of both packages: