Skip to content

Commit

Permalink
Fixing export_onnx
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqwang committed Sep 30, 2021
1 parent 06022fd commit f9ab3fa
Showing 1 changed file with 4 additions and 14 deletions.
18 changes: 4 additions & 14 deletions tools/export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,14 @@
except ImportError:
onnxsim = None

from yolort import models
from yolort.models import YOLOv5


def get_parser():
parser = argparse.ArgumentParser('CLI tool for exporting ONNX models', add_help=True)

parser.add_argument('--checkpoint_path', type=str, required=True,
help='The path of checkpoint weights')
# Model architecture
parser.add_argument('--arch', choices=['yolov5s', 'yolov5m', 'yolov5l'], default='yolov5s',
help='Model architecture to export')
parser.add_argument('--num_classes', default=80, type=int,
help='The number of classes')
parser.add_argument('--score_thresh', default=0.25, type=float,
help='Score threshold used for postprocessing the detections.')
parser.add_argument('--export_friendly', action='store_true',
Expand Down Expand Up @@ -86,15 +81,10 @@ def cli_main():
assert checkpoint_path.is_file(), f'Not found checkpoint: {checkpoint_path}'

# input data
images = torch.rand(3, args.image_size, args.image_size)
inputs = ([images], )
images = [torch.rand(3, args.image_size, args.image_size)]
inputs = (images, )

model = models.__dict__[args.arch](
num_classes=args.num_classes,
export_friendly=args.export_friendly,
score_thresh=args.score_thresh
)
model.load_from_yolov5(checkpoint_path)
model = YOLOv5.load_from_yolov5(checkpoint_path, score_thresh=args.score_thresh)
model.eval()

# export ONNX models
Expand Down

0 comments on commit f9ab3fa

Please sign in to comment.