From b5e822a4384ce626fc961b722a100edc372aef77 Mon Sep 17 00:00:00 2001 From: ZwX1616 Date: Wed, 18 Sep 2024 14:08:13 -0700 Subject: [PATCH] set input tensor dtype --- openpilot/compile2.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/openpilot/compile2.py b/openpilot/compile2.py index 1ee84f91ebb3..e9fa3d4cfddc 100644 --- a/openpilot/compile2.py +++ b/openpilot/compile2.py @@ -30,9 +30,10 @@ def get_schedule(onnx_data) -> Tuple[List[ScheduleItem], List[ScheduleItem]]: onnx_model = onnx.load(io.BytesIO(onnx_data)) run_onnx = get_run_onnx(onnx_model) input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input} + input_types = {inp.name:onnx.helper.tensor_dtype_to_np_dtype(inp.type.tensor_type.elem_type) for inp in onnx_model.graph.input} # run the model - inputs = {k:Tensor.empty(*shp) for k,shp in input_shapes.items()} + inputs = {k:Tensor.empty(*shp, dtype=getattr(dtypes, input_types[k].name)) for k,shp in input_shapes.items()} ret: Tensor = next(iter(run_onnx(inputs).values())).cast(dtypes.float32).contiguous() schedule = ret.lazydata.schedule()