
How did you convert the model to PyTorch Mobile? #1

Closed
ramtiin opened this issue Apr 14, 2022 · 2 comments

Comments


ramtiin commented Apr 14, 2022

Hi,

May I know how you converted the original Python model to the mobile (lite) version?

juanjaho (Owner) commented Apr 15, 2022

Hi @ramtiin,

Using the code attached below; both conversions are done by tracing the model with sample input data.
- For the ArcaneGAN model, I converted the .jit model to .ptl.
- For the AnimeGAN model, the .pt model is converted to .ptl.
- https://github.com/bryandlee/animegan2-pytorch also has detailed documentation on how the TensorFlow weights were converted to .pt.

convert_model.zip

.jit to .ptl (ArcaneGAN model)

import torch
from torch.utils.mobile_optimizer import optimize_for_mobile
from torchvision import transforms
from PIL import Image

# Load the TorchScript (.jit) ArcaneGAN model on CPU in fp32 and switch to eval mode.
model = torch.jit.load("model.jit", map_location="cpu").float().eval().cpu()

# Sample image used to trace the model.
image = Image.open("test.jpg").convert("RGB")

# Normalization statistics (ImageNet means/stds) applied to the input.
means = [0.485, 0.456, 0.406]
stds = [0.229, 0.224, 0.225]

img_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(means, stds)])

with torch.no_grad():
    # Add a batch dimension: (3, H, W) -> (1, 3, H, W).
    image = img_transforms(image)[None, ...].cpu()
    # Trace the model with the sample input, optimize it for mobile,
    # and save it in the lite-interpreter (.ptl) format.
    traced_script_module = torch.jit.trace(model, image)
    optimized_traced_model = optimize_for_mobile(traced_script_module)
    optimized_traced_model._save_for_lite_interpreter("model.ptl")
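
To sanity-check the exported .ptl, you can reload it with the lite interpreter on the desktop and compare its output against the original TorchScript model. The following is a minimal sketch, assuming the model.jit/model.ptl filenames from the script above, an arbitrary 256x256 dummy input, and the private helper _load_for_lite_interpreter from torch.jit.mobile (present in recent PyTorch releases, but not a public API).

import torch
# Private helper for loading lite-interpreter (.ptl) files in Python;
# lives in torch.jit.mobile in recent PyTorch releases (not a public API).
from torch.jit.mobile import _load_for_lite_interpreter

# Load both the original TorchScript model and the exported lite module on CPU.
jit_model = torch.jit.load("model.jit", map_location="cpu").float().eval()
lite_model = _load_for_lite_interpreter("model.ptl")

# Dummy input with the same layout as the traced input: (1, 3, H, W).
# The 256x256 size is an arbitrary choice for this check.
dummy = torch.randn(1, 3, 256, 256)

with torch.no_grad():
    out_jit = jit_model(dummy)
    out_lite = lite_model(dummy)

# The two outputs should agree up to small numerical differences
# introduced by the mobile optimization passes.
print(torch.allclose(out_jit, out_lite, atol=1e-4))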

.pt to .ptl (AnimeGAN model)

import torch
from torch.utils.mobile_optimizer import optimize_for_mobile
from PIL import Image
from torchvision.transforms.functional import to_tensor, to_pil_image

# Generator definition from https://github.com/bryandlee/animegan2-pytorch
from model import Generator

# Rebuild the generator and load the .pt weights on CPU.
model = Generator()
model.load_state_dict(torch.load("model.pt", map_location="cpu"))
model.to('cpu').eval()

# Sample image used to trace the model.
image = Image.open("test.jpg").convert("RGB")

with torch.no_grad():
    # Add a batch dimension and scale pixel values from [0, 1] to [-1, 1],
    # the input range the AnimeGAN generator expects.
    image = to_tensor(image).unsqueeze(0) * 2 - 1
    # Trace, optimize for mobile, and save in the lite-interpreter (.ptl) format.
    traced_script_module = torch.jit.trace(model, image)
    optimized_traced_model = optimize_for_mobile(traced_script_module)
    optimized_traced_model._save_for_lite_interpreter("model.ptl")
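
To turn the AnimeGAN output back into an image, the tensor has to be mapped from roughly [-1, 1] back to [0, 1] before calling to_pil_image, mirroring the post-processing in the upstream animegan2-pytorch demo. A minimal sketch, assuming the same model.pt weights and Generator class as above (the file names are placeholders):

import torch
from PIL import Image
from torchvision.transforms.functional import to_tensor, to_pil_image

from model import Generator  # Generator from bryandlee/animegan2-pytorch, as above

# Rebuild the generator exactly as in the conversion script.
model = Generator()
model.load_state_dict(torch.load("model.pt", map_location="cpu"))
model.to("cpu").eval()

image = Image.open("test.jpg").convert("RGB")

with torch.no_grad():
    # Scale pixel values to [-1, 1], run the generator, and take the single image in the batch.
    x = to_tensor(image).unsqueeze(0) * 2 - 1
    out = model(x)[0]
    # Map the output back to [0, 1] before converting to a PIL image.
    out = (out * 0.5 + 0.5).clamp(0, 1)

to_pil_image(out).save("test_out.jpg")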

ramtiin (Author) commented Apr 18, 2022

Thank you so much!
