From 545a3671eec33bf30c448a7316d1c1b110c6f238 Mon Sep 17 00:00:00 2001 From: Molly Sophia Date: Wed, 31 Jan 2024 21:12:16 +0800 Subject: [PATCH] python pnnx add option to change fp16 parameter (#5320) Signed-off-by: Molly Sophia --- tools/pnnx/python/pnnx/utils/convert.py | 4 +++- tools/pnnx/python/pnnx/utils/export.py | 4 ++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/tools/pnnx/python/pnnx/utils/convert.py b/tools/pnnx/python/pnnx/utils/convert.py index 380655ba001..6e5f78cc644 100644 --- a/tools/pnnx/python/pnnx/utils/convert.py +++ b/tools/pnnx/python/pnnx/utils/convert.py @@ -21,7 +21,7 @@ def convert(ptpath, inputs = None, inputs2 = None, input_shapes = None, input_types = None, input_shapes2 = None, input_types2 = None, device = None, customop = None, moduleop = None, optlevel = None, pnnxparam = None, pnnxbin = None, - pnnxpy = None, pnnxonnx = None, ncnnparam = None, ncnnbin = None, ncnnpy = None): + pnnxpy = None, pnnxonnx = None, ncnnparam = None, ncnnbin = None, ncnnpy = None, fp16 = True): check_type(ptpath, "modelname", [str], "str") check_type(inputs, "inputs", [torch.Tensor, tuple, list], "torch.Tensor or tuple/list of torch.Tensor") @@ -106,6 +106,8 @@ def convert(ptpath, inputs = None, inputs2 = None, input_shapes = None, input_ty command_list.append("ncnnbin=" + ncnnbin) if not (ncnnpy is None): command_list.append("ncnnpy=" + ncnnpy) + if not (fp16 is True): + command_list.append("fp16=0") current_dir = os.getcwd() subprocess.run(command_list, stdout=subprocess.PIPE, text=True, cwd=current_dir) diff --git a/tools/pnnx/python/pnnx/utils/export.py b/tools/pnnx/python/pnnx/utils/export.py index 3a04dae733d..6e24954efcd 100644 --- a/tools/pnnx/python/pnnx/utils/export.py +++ b/tools/pnnx/python/pnnx/utils/export.py @@ -19,7 +19,7 @@ def export(model, ptpath, inputs = None, inputs2 = None, input_shapes = None, in input_shapes2 = None, input_types2 = None, device = None, customop = None, moduleop = None, optlevel = None, pnnxparam = None, pnnxbin = None, pnnxpy = None, pnnxonnx = None, ncnnparam = None, ncnnbin = None, ncnnpy = None, - check_trace=True): + check_trace = True, fp16 = True): if (inputs is None) and (input_shapes is None): raise Exception("inputs or input_shapes should be specified.") if not (input_shapes is None) and (input_types is None): @@ -30,4 +30,4 @@ def export(model, ptpath, inputs = None, inputs2 = None, input_shapes = None, in mod.save(ptpath) from . import convert - return convert(ptpath, inputs, inputs2, input_shapes, input_types, input_shapes2, input_types2, device, customop, moduleop, optlevel, pnnxparam, pnnxbin, pnnxpy, pnnxonnx, ncnnparam, ncnnbin, ncnnpy) + return convert(ptpath, inputs, inputs2, input_shapes, input_types, input_shapes2, input_types2, device, customop, moduleop, optlevel, pnnxparam, pnnxbin, pnnxpy, pnnxonnx, ncnnparam, ncnnbin, ncnnpy, fp16)