Skip to content

Commit

Permalink
[MHLO] bert-tiny and resnet18 example from torchscript to mhlo (#1266)
Browse files Browse the repository at this point in the history
Co-authored-by: Bairen Yi <yibairen.byron@bytedance.com>
Co-authored-by: Jiawei Wu <xremold@gmail.com>
Co-authored-by: Tianyou Guo <tianyou.gty@alibaba-inc.com>
Co-authored-by: Xu Yan <yancey.yx@alibaba-inc.com>
Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com>
Co-authored-by: Vremold <xremold@gamil.com>
  • Loading branch information
7 people authored Aug 23, 2022
1 parent 2374098 commit 1106b9a
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 0 deletions.
14 changes: 14 additions & 0 deletions examples/torchscript_mhlo_backend_resnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import torch
import torchvision.models as models
import torch_mlir

model = models.resnet18(pretrained=True)
model.eval()
data = torch.randn(2,3,200,200)
out_mhlo_mlir_path = "./resnet18_mhlo.mlir"

module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.MHLO, use_tracing=False)
with open(out_mhlo_mlir_path, "w", encoding="utf-8") as outf:
outf.write(str(module))

print(f"MHLO IR of resent18 successfully written into {out_mhlo_mlir_path}")
24 changes: 24 additions & 0 deletions examples/torchscript_mhlo_backend_tinybert.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import torch
import torch_mlir

from transformers import BertForMaskedLM

# Wrap the bert model to avoid multiple returns problem
class BertTinyWrapper(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
self.bert = BertForMaskedLM.from_pretrained("prajjwal1/bert-tiny", return_dict=False)

def forward(self, data):
return self.bert(data)[0]

model = BertTinyWrapper()
model.eval()
data = torch.randint(30522, (2, 128))
out_mhlo_mlir_path = "./bert_tiny_mhlo.mlir"

module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.MHLO, use_tracing=True)
with open(out_mhlo_mlir_path, "w", encoding="utf-8") as outf:
outf.write(str(module))

print(f"MHLO IR of tiny bert successfully written into {out_mhlo_mlir_path}")

0 comments on commit 1106b9a

Please sign in to comment.