-
Notifications
You must be signed in to change notification settings - Fork 11
/
onnx_export.py
49 lines (40 loc) · 1.98 KB
/
onnx_export.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch
import argparse
from sapiens_inference.common import TaskType, download_hf_model
from sapiens_inference import SapiensSegmentationType, SapiensNormalType, SapiensDepthType
# Use CUDA when it is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# CLI model name -> (Sapiens checkpoint enum member, TaskType it belongs to).
# Insertion order is kept; get_parser() exposes these keys as the CLI choices.
model_type_dict = dict(
    seg03b=(SapiensSegmentationType.SEGMENTATION_03B, TaskType.SEG),
    seg06b=(SapiensSegmentationType.SEGMENTATION_06B, TaskType.SEG),
    seg1b=(SapiensSegmentationType.SEGMENTATION_1B, TaskType.SEG),
    normal03b=(SapiensNormalType.NORMAL_03B, TaskType.NORMAL),
    normal06b=(SapiensNormalType.NORMAL_06B, TaskType.NORMAL),
    normal1b=(SapiensNormalType.NORMAL_1B, TaskType.NORMAL),
    normal2b=(SapiensNormalType.NORMAL_2B, TaskType.NORMAL),
    depth03b=(SapiensDepthType.DEPTH_03B, TaskType.DEPTH),
    depth06b=(SapiensDepthType.DEPTH_06B, TaskType.DEPTH),
    depth1b=(SapiensDepthType.DEPTH_1B, TaskType.DEPTH),
    depth2b=(SapiensDepthType.DEPTH_2B, TaskType.DEPTH),
)
@torch.no_grad()
def export_model(model_name: str, filename: str):
    """Download a Sapiens TorchScript checkpoint and export it to ONNX.

    The checkpoint is loaded onto the module-level ``device``, cast to
    float32, traced with a random dummy input and written to ``filename``
    with a single input named ``"input"`` and a single output named
    ``"output"`` (ONNX opset 14).

    Args:
        model_name: Key into ``model_type_dict`` (e.g. ``"seg03b"``,
            ``"depth1b"``) selecting the checkpoint and its task.
        filename: Destination path of the ``.onnx`` file.

    Raises:
        KeyError: If ``model_name`` is not a key of ``model_type_dict``.
    """
    model_type, task_type = model_type_dict[model_name]
    # Pass the task paired with this model in model_type_dict. The previous
    # code always passed TaskType.SEG, ignoring the looked-up task for the
    # normal/depth models.
    path = download_hf_model(model_type.value, task_type)
    # map_location lets a checkpoint that was serialized on a GPU load on a
    # CPU-only host (the same fallback `device` already applies).
    model = torch.jit.load(path, map_location=device)
    model = model.eval().to(device).to(torch.float32)
    # Only this size seems to work well.
    dummy_input = torch.randn(1, 3, 1024, 768, dtype=torch.float32, device=device)
    torch.onnx.export(
        model,
        dummy_input,
        filename,
        export_params=True,
        do_constant_folding=True,
        opset_version=14,
        input_names=["input"],
        output_names=["output"],
    )
def get_parser():
    """Build the command-line parser for the ONNX export script.

    The only argument is the positional ``model_name``, restricted to the
    keys of ``model_type_dict``.
    """
    cli = argparse.ArgumentParser(description="Export Sapiens models to ONNX")
    cli.add_argument(
        "model_name",
        type=str,
        choices=model_type_dict.keys(),
        help="Model type to export",
    )
    return cli
if __name__ == "__main__":
    args = get_parser().parse_args()
    # get_parser() registers the positional argument as "model_name", so
    # argparse stores it as args.model_name; args.model does not exist and
    # raised AttributeError on every run.
    export_model(args.model_name, f"{args.model_name}.onnx")