Skip to content

Commit

Permalink
Ready for review
Browse files Browse the repository at this point in the history
  • Loading branch information
guberti committed May 20, 2022
1 parent c330e4b commit e7f3879
Showing 1 changed file with 12 additions and 9 deletions.
21 changes: 12 additions & 9 deletions gallery/how_to/work_with_microtvm/micro_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
# .. image:: https://raw.githubusercontent.com/guberti/web-data/micro-train-tutorial-data/images/utilities/colab_button.png
# :align: center
# :target: https://colab.research.google.com/github/guberti/tvm-site/blob/asf-site/docs/_downloads/a7c7ea4b5017ae70db1f51dd8e6dcd82/micro_train.ipynb
# :width: 600px
#
# Motivation
# ----------
Expand Down Expand Up @@ -98,7 +99,8 @@
# ^^^^^^^^^^^^^^^^^^^^^
# We need to pick a directory where our image datasets, trained model, and eventual Arduino sketch
# will all live. If running on Google Colab, we'll save everything in ``/root`` (aka ``~``) but you'll
# probably want to store it elsewhere if running locally.
# probably want to store it elsewhere if running locally. Note that this variable only affects Python
# scripts - you'll have to adjust the Bash commands too.

import os

Expand Down Expand Up @@ -172,6 +174,7 @@
shutil.move(f"{FOLDER}/images/cars_train", f"{FOLDER}/images/target")
shutil.move(f"{FOLDER}/images/val2017", f"{FOLDER}/images/random")

######################################################################
# Loading the Data
# ----------------
# Currently, our data is stored on-disk as JPG files of various sizes. To train with it, we'll have
Expand Down Expand Up @@ -230,8 +233,8 @@
plt.axis("off")

######################################################################
# What's Inside Our Dataset?
# ^^^^^^^^^^^^^^^^^^^^^^^^^^
# Validating our Accuracy
# ^^^^^^^^^^^^^^^^^^^^^^^
# While developing our model, we'll often want to check how accurate it is (e.g. to see if it
# improves during training). How do we do this? We could just train it on *all* of the data, and
# then ask it to classify that same data. However, our model could cheat by just memorizing all of
Expand Down Expand Up @@ -341,7 +344,7 @@
loss="categorical_crossentropy",
metrics=["accuracy"],
)
# model.fit(train_dataset, validation_data=validation_dataset, epochs=3, verbose=2)
model.fit(train_dataset, validation_data=validation_dataset, epochs=3, verbose=2)

######################################################################
# Quantization
Expand Down Expand Up @@ -506,7 +509,7 @@ def representative_dataset():
# .. code-block:: bash
#
# %%bash
# mkdir -p /root/tests
# mkdir -p ~/tests
# curl "https://i.imgur.com/JBbEhxN.png" -o ~/tests/car_224.png
# convert ~/tests/car_224.png -resize 64 ~/tests/car_64.png
# stream ~/tests/car_64.png ~/tests/car.raw
Expand Down Expand Up @@ -559,7 +562,7 @@ def representative_dataset():
# subject for a different tutorial. To finish up, we'll first test that our program compiles does
# not throw any compiler errors:

shutil.rmtree("{FOLDER}/models/project/build", ignore_errors=True)
shutil.rmtree(f"{FOLDER}/models/project/build", ignore_errors=True)
arduino_project.build()
print("Compilation succeeded!")

Expand All @@ -572,7 +575,7 @@ def representative_dataset():
# If you're running on Google Colab, you'll have to uncomment the last two lines to download the file
# after writing it.

ZIP_FOLDER = "{FOLDER}/models/project"
ZIP_FOLDER = f"{FOLDER}/models/project"
shutil.make_archive(ZIP_FOLDER, "zip", ZIP_FOLDER)
# from google.colab import files
# files.download(f"{FOLDER}/models/project.zip")
Expand All @@ -587,8 +590,8 @@ def representative_dataset():
assert len(quantized_model) <= 350000 # Exact value depends on quantization

# Assert .tflite and .zip files were written to disk
assert os.path.isfile("{FOLDER}/models/quantized.tflite")
assert os.path.isfile("{FOLDER}/models/project.zip")
assert os.path.isfile(f"{FOLDER}/models/quantized.tflite")
assert os.path.isfile(f"{FOLDER}/models/project.zip")

# Assert MLF file was correctly generated
assert str(mod.executor) == "aot"
Expand Down

0 comments on commit e7f3879

Please sign in to comment.