From 9dda6d260db0255750bacff61e3cee1e580567e1 Mon Sep 17 00:00:00 2001 From: Bruce Wayne Date: Sat, 28 Sep 2024 18:14:12 -0700 Subject: [PATCH] Works with uint8 models --- openpilot/compile2.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/openpilot/compile2.py b/openpilot/compile2.py index 6d21d48898f7..3df3eb452946 100644 --- a/openpilot/compile2.py +++ b/openpilot/compile2.py @@ -22,7 +22,7 @@ from tinygrad.features.image import fix_schedule_for_images Device.DEFAULT = "GPU" -def get_schedule(onnx_data, force_input_type=None) -> Tuple[List[ScheduleItem], List[ScheduleItem]]: +def get_schedule(onnx_data, supercombo_dtypes=False) -> Tuple[List[ScheduleItem], List[ScheduleItem]]: Tensor.no_grad = True Tensor.training = False @@ -30,11 +30,16 @@ def get_schedule(onnx_data, force_input_type=None) -> Tuple[List[ScheduleItem], onnx_model = onnx.load(io.BytesIO(onnx_data)) run_onnx = get_run_onnx(onnx_model) input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input} - input_types = {inp.name:onnx.helper.tensor_dtype_to_np_dtype(inp.type.tensor_type.elem_type) for inp in onnx_model.graph.input} + input_types_onnx = {inp.name:onnx.helper.tensor_dtype_to_np_dtype(inp.type.tensor_type.elem_type) for inp in onnx_model.graph.input} # run the model - inputs = {k:Tensor.empty(*shp, - dtype=force_input_type if force_input_type is not None else getattr(dtypes, input_types[k].name)) for k,shp in input_shapes.items()} + input_types = {k:getattr(dtypes, input_types_onnx[k].name) for k in input_types_onnx} + print(input_types) + if supercombo_dtypes: + input_types = {k:dtypes.float32 if 'img' not in k else dtypes.uint8 for k in input_types} + print(input_types) + + inputs = {k:Tensor.empty(*shp, dtype=input_types[k]) for k,shp in input_shapes.items()} ret: Tensor = next(iter(run_onnx(inputs).values())).cast(dtypes.float32).contiguous() schedule = ret.lazydata.schedule() @@ 
-139,8 +144,8 @@ def thneed_test_onnx(onnx_data, output_fn): #exit(0) # this is a hack due to supercombo being converted with f16 inputs but it uses f32 at runtime - force_input_type = dtypes.float32 if 'supercombo' in onnx_fn else None - schedule, schedule_independent, inputs = get_schedule(onnx_data, force_input_type=force_input_type) + supercombo = 'supercombo' in onnx_fn + schedule, schedule_independent, inputs = get_schedule(onnx_data, supercombo_dtypes=supercombo) schedule, schedule_input = partition(schedule, lambda x: x.ast.op not in LoadOps) print(f"{len(schedule_input)} inputs")