diff --git a/tutorials/data/wolves.jpg b/tutorials/data/wolves.jpg new file mode 100644 index 000000000..8520bffbd Binary files /dev/null and b/tutorials/data/wolves.jpg differ diff --git a/tutorials/yolo_e2e.py b/tutorials/yolo_e2e.py index 5bc923aea..551931378 100644 --- a/tutorials/yolo_e2e.py +++ b/tutorials/yolo_e2e.py @@ -1,6 +1,8 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT License. +import argparse import numpy +import os from pathlib import Path import onnxruntime_extensions @@ -34,7 +36,7 @@ def add_pre_post_processing_to_yolo(input_model_file: Path, output_model_file: P add_ppp.yolo_detection(input_model_file, output_model_file, "jpg", onnx_opset=18) -def run_inference(onnx_model_file: Path): +def run_inference(onnx_model_file: Path, test_image: Path): import onnxruntime as ort import numpy as np @@ -42,22 +44,36 @@ def run_inference(onnx_model_file: Path): session_options = ort.SessionOptions() session_options.register_custom_ops_library(onnxruntime_extensions.get_library_path()) - image = np.frombuffer(open('../test/data/ppp_vision/wolves.jpg', 'rb').read(), dtype=np.uint8) + image = np.frombuffer(open(f'{test_image}', 'rb').read(), dtype=np.uint8) session = ort.InferenceSession(str(onnx_model_file), providers=providers, sess_options=session_options) inname = [i.name for i in session.get_inputs()] inp = {inname[0]: image} output = session.run(['image_out'], inp)[0] - output_filename = '../test/data/result.jpg' + filename, extension = os.path.splitext(test_image) + output_filename = Path(f'{filename}.out{extension}') open(output_filename, 'wb').write(output) from PIL import Image Image.open(output_filename).show() if __name__ == '__main__': + + script_dir = Path( __file__ ).parent.absolute() + # YOLO version. Tested with 5 and 8. version = 8 - onnx_model_name = Path(f"../test/data/yolov{version}n.onnx") + + parser = argparse.ArgumentParser("Add pre and post processing to the YOLOv8 model.") + parser.add_argument("--onnx_model_path", type=Path, default=f"yolov{version}n.onnx", + help="The location and name of the file to output the ONNX YOLO model.") + parser.add_argument("--test_image", type=Path, default=f"{script_dir}/data/wolves.jpg") + + args = parser.parse_args() + onnx_model_path = args.onnx_model_path + + + onnx_model_name = Path(f"{onnx_model_path}") if not onnx_model_name.exists(): print("Fetching original model...") get_yolo_model(version, str(onnx_model_name)) @@ -66,4 +82,4 @@ def run_inference(onnx_model_file: Path): print("Adding pre/post processing...") add_pre_post_processing_to_yolo(onnx_model_name, onnx_e2e_model_name) print("Testing updated model...") - run_inference(onnx_e2e_model_name) + run_inference(onnx_e2e_model_name, args.test_image)