Skip to content

Commit

Permalink
fix: keras model changed to pytorch format (apache#67)
Browse files Browse the repository at this point in the history
  • Loading branch information
srkreddy1238 committed Nov 11, 2024
1 parent b130960 commit f4977d1
Showing 1 changed file with 23 additions and 9 deletions.
32 changes: 23 additions & 9 deletions gallery/how_to/deploy_models/deploy_model_on_adreno_tvmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
==========================================================
**Author**: Siva Rama Krishna
This article is a step-by-step tutorial to deploy pretrained Keras resnet50 model on Adreno™.
This article is a step-by-step tutorial to deploy pretrained PyTorch resnet50 model on Adreno™.
Besides that, you should have TVM built for Android.
See the following instructions on how to build it and setup RPC environment.
Expand Down Expand Up @@ -71,16 +71,27 @@
)

#######################################################################
# Make a Keras Resnet50 Model
# Make a PyTorch Resnet50 Model
# ---------------------------

from tensorflow.keras.applications.resnet50 import ResNet50
import torch
import torchvision.models as models

tmp_path = utils.tempdir()
model_file_name = tmp_path.relpath("resnet50.h5")
# Load the ResNet50 model pre-trained on ImageNet
model = models.resnet50(pretrained=True)

model = ResNet50(include_top=True, weights="imagenet", input_shape=(224, 224, 3), classes=1000)
model.save(model_file_name)
# Set the model to evaluation mode
model.eval()

# Define the input shape
dummy_input = torch.randn(1, 3, 224, 224)

# Trace the model
traced_model = torch.jit.trace(model, dummy_input)

# Save the traced model
model_file_name = "resnet50_traced.pt"
traced_model.save(model_file_name)


#######################################################################
Expand All @@ -89,7 +100,10 @@
# Convert a model from any framework to a tvm relay module.
# tvmc.load supports models from any framework (like tensorflow saves_model, onnx, tflite ..etc) and auto detects the filetype.

tvmc_model = tvmc.load(model_file_name)
input_shape = (1, 3, 224, 224) # Batch size, channels, height, width

# Load the TorchScript model with TVMC
tvmc_model = tvmc.load(model_file_name, shape_dict={"input": input_shape}, model_format="pytorch")

print(tvmc_model.mod)

Expand Down Expand Up @@ -158,7 +172,7 @@
# Altrernatively, we can save the compilation output and save it as a TVMCPackage.
# This way avoids loading of compiled module without compiling again.
target = target + ", clml"
pkg_path = tmp_path.relpath("keras-resnet50.tar")
pkg_path = tmp_path.relpath("torch-resnet50.tar")
tvmc.compile(
tvmc_model,
target=target,
Expand Down

0 comments on commit f4977d1

Please sign in to comment.