From e7f3879463ed55b5496ef86b53970987c09b4fb2 Mon Sep 17 00:00:00 2001 From: Gavin Uberti Date: Mon, 2 May 2022 16:45:09 -0400 Subject: [PATCH] Ready for review --- .../how_to/work_with_microtvm/micro_train.py | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/gallery/how_to/work_with_microtvm/micro_train.py b/gallery/how_to/work_with_microtvm/micro_train.py index d1ee33f32b37c..83d7d3c870835 100644 --- a/gallery/how_to/work_with_microtvm/micro_train.py +++ b/gallery/how_to/work_with_microtvm/micro_train.py @@ -36,6 +36,7 @@ # .. image:: https://raw.githubusercontent.com/guberti/web-data/micro-train-tutorial-data/images/utilities/colab_button.png # :align: center # :target: https://colab.research.google.com/github/guberti/tvm-site/blob/asf-site/docs/_downloads/a7c7ea4b5017ae70db1f51dd8e6dcd82/micro_train.ipynb +# :width: 600px # # Motivation # ---------- @@ -98,7 +99,8 @@ # ^^^^^^^^^^^^^^^^^^^^^ # We need to pick a directory where our image datasets, trained model, and eventual Arduino sketch # will all live. If running on Google Colab, we'll save everything in ``/root`` (aka ``~``) but you'll -# probably want to store it elsewhere if running locally. +# probably want to store it elsewhere if running locally. Note that this variable only affects Python +# scripts - you'll have to adjust the Bash commands too. import os @@ -172,6 +174,7 @@ shutil.move(f"{FOLDER}/images/cars_train", f"{FOLDER}/images/target") shutil.move(f"{FOLDER}/images/val2017", f"{FOLDER}/images/random") +###################################################################### # Loading the Data # ---------------- # Currently, our data is stored on-disk as JPG files of various sizes. To train with it, we'll have @@ -230,8 +233,8 @@ plt.axis("off") ###################################################################### -# What's Inside Our Dataset? -# ^^^^^^^^^^^^^^^^^^^^^^^^^^ +# Validating our Accuracy +# ^^^^^^^^^^^^^^^^^^^^^^^ # While developing our model, we'll often want to check how accurate it is (e.g. to see if it # improves during training). How do we do this? We could just train it on *all* of the data, and # then ask it to classify that same data. However, our model could cheat by just memorizing all of @@ -341,7 +344,7 @@ loss="categorical_crossentropy", metrics=["accuracy"], ) -# model.fit(train_dataset, validation_data=validation_dataset, epochs=3, verbose=2) +model.fit(train_dataset, validation_data=validation_dataset, epochs=3, verbose=2) ###################################################################### # Quantization @@ -506,7 +509,7 @@ def representative_dataset(): # .. code-block:: bash # # %%bash -# mkdir -p /root/tests +# mkdir -p ~/tests # curl "https://i.imgur.com/JBbEhxN.png" -o ~/tests/car_224.png # convert ~/tests/car_224.png -resize 64 ~/tests/car_64.png # stream ~/tests/car_64.png ~/tests/car.raw @@ -559,7 +562,7 @@ def representative_dataset(): # subject for a different tutorial. To finish up, we'll first test that our program compiles does # not throw any compiler errors: -shutil.rmtree("{FOLDER}/models/project/build", ignore_errors=True) +shutil.rmtree(f"{FOLDER}/models/project/build", ignore_errors=True) arduino_project.build() print("Compilation succeeded!") @@ -572,7 +575,7 @@ def representative_dataset(): # If you're running on Google Colab, you'll have to uncomment the last two lines to download the file # after writing it. -ZIP_FOLDER = "{FOLDER}/models/project" +ZIP_FOLDER = f"{FOLDER}/models/project" shutil.make_archive(ZIP_FOLDER, "zip", ZIP_FOLDER) # from google.colab import files # files.download(f"{FOLDER}/models/project.zip") @@ -587,8 +590,8 @@ def representative_dataset(): assert len(quantized_model) <= 350000 # Exact value depends on quantization # Assert .tflite and .zip files were written to disk -assert os.path.isfile("{FOLDER}/models/quantized.tflite") -assert os.path.isfile("{FOLDER}/models/project.zip") +assert os.path.isfile(f"{FOLDER}/models/quantized.tflite") +assert os.path.isfile(f"{FOLDER}/models/project.zip") # Assert MLF file was correctly generated assert str(mod.executor) == "aot"