Skip to content

Commit

Permalink
Works with uint8 models
Browse files Browse the repository at this point in the history
  • Loading branch information
haraschax committed Sep 29, 2024
1 parent 3e15fa0 commit 9dda6d2
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions openpilot/compile2.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,24 @@
from tinygrad.features.image import fix_schedule_for_images
Device.DEFAULT = "GPU"

def get_schedule(onnx_data, force_input_type=None) -> Tuple[List[ScheduleItem], List[ScheduleItem]]:
def get_schedule(onnx_data, supercombo_dtypes=False) -> Tuple[List[ScheduleItem], List[ScheduleItem]]:
Tensor.no_grad = True
Tensor.training = False

# load the model
onnx_model = onnx.load(io.BytesIO(onnx_data))
run_onnx = get_run_onnx(onnx_model)
input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input}
input_types = {inp.name:onnx.helper.tensor_dtype_to_np_dtype(inp.type.tensor_type.elem_type) for inp in onnx_model.graph.input}
input_types_onnx = {inp.name:onnx.helper.tensor_dtype_to_np_dtype(inp.type.tensor_type.elem_type) for inp in onnx_model.graph.input}

# run the model
inputs = {k:Tensor.empty(*shp,
dtype=force_input_type if force_input_type is not None else getattr(dtypes, input_types[k].name)) for k,shp in input_shapes.items()}
input_types = {k:getattr(dtypes, input_types_onnx[k].name)for k in input_types_onnx.keys()}
print(input_types)
if supercombo_dtypes:
input_types = {k:dtypes.float32 if 'img' not in k else dtypes.uint8 for k in input_types.keys()}
print(input_types)

inputs = {k:Tensor.empty(*shp, dtype=input_types[k]) for k,shp in input_shapes.items()}
ret: Tensor = next(iter(run_onnx(inputs).values())).cast(dtypes.float32).contiguous()
schedule = ret.lazydata.schedule()

Expand Down Expand Up @@ -139,8 +144,8 @@ def thneed_test_onnx(onnx_data, output_fn):
#exit(0)

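# HACK: supercombo was converted with f16 inputs, but at runtime it uses f32 inputs (uint8 for the 'img' inputs), so get_schedule overrides the ONNX input dtypes
# NOTE(review): `supercombo` below is a non-empty string, so it is always truthy and the dtype override is applied to every model -- confirm this is intended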
force_input_type = dtypes.float32 if 'supercombo' in onnx_fn else None
schedule, schedule_independent, inputs = get_schedule(onnx_data, force_input_type=force_input_type)
supercombo = 'supercombo'
schedule, schedule_independent, inputs = get_schedule(onnx_data, supercombo_dtypes=supercombo)
schedule, schedule_input = partition(schedule, lambda x: x.ast.op not in LoadOps)
print(f"{len(schedule_input)} inputs")

Expand Down

0 comments on commit 9dda6d2

Please sign in to comment.