-
Notifications
You must be signed in to change notification settings - Fork 507
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[MHLO] bert-tiny and resnet18 example from torchscript to mhlo (#1266)
Co-authored-by: Bairen Yi <yibairen.byron@bytedance.com> Co-authored-by: Jiawei Wu <xremold@gmail.com> Co-authored-by: Tianyou Guo <tianyou.gty@alibaba-inc.com> Co-authored-by: Xu Yan <yancey.yx@alibaba-inc.com> Co-authored-by: Ziheng Jiang <ziheng.jiang@bytedance.com> Co-authored-by: Vremold <xremold@gamil.com>
- Loading branch information
1 parent
2374098
commit 1106b9a
Showing
2 changed files
with
38 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import torch | ||
import torchvision.models as models | ||
import torch_mlir | ||
|
||
model = models.resnet18(pretrained=True) | ||
model.eval() | ||
data = torch.randn(2,3,200,200) | ||
out_mhlo_mlir_path = "./resnet18_mhlo.mlir" | ||
|
||
module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.MHLO, use_tracing=False) | ||
with open(out_mhlo_mlir_path, "w", encoding="utf-8") as outf: | ||
outf.write(str(module)) | ||
|
||
print(f"MHLO IR of resent18 successfully written into {out_mhlo_mlir_path}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import torch | ||
import torch_mlir | ||
|
||
from transformers import BertForMaskedLM | ||
|
||
# Wrap the bert model to avoid multiple returns problem | ||
class BertTinyWrapper(torch.nn.Module): | ||
def __init__(self) -> None: | ||
super().__init__() | ||
self.bert = BertForMaskedLM.from_pretrained("prajjwal1/bert-tiny", return_dict=False) | ||
|
||
def forward(self, data): | ||
return self.bert(data)[0] | ||
|
||
model = BertTinyWrapper() | ||
model.eval() | ||
data = torch.randint(30522, (2, 128)) | ||
out_mhlo_mlir_path = "./bert_tiny_mhlo.mlir" | ||
|
||
module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.MHLO, use_tracing=True) | ||
with open(out_mhlo_mlir_path, "w", encoding="utf-8") as outf: | ||
outf.write(str(module)) | ||
|
||
print(f"MHLO IR of tiny bert successfully written into {out_mhlo_mlir_path}") |