Skip to content

Commit

Permalink
set input tensor dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
ZwX1616 committed Sep 18, 2024
1 parent f51aa0f commit b5e822a
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion openpilot/compile2.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,10 @@ def get_schedule(onnx_data) -> Tuple[List[ScheduleItem], List[ScheduleItem]]:
onnx_model = onnx.load(io.BytesIO(onnx_data))
run_onnx = get_run_onnx(onnx_model)
input_shapes = {inp.name:tuple(x.dim_value for x in inp.type.tensor_type.shape.dim) for inp in onnx_model.graph.input}
input_types = {inp.name:onnx.helper.tensor_dtype_to_np_dtype(inp.type.tensor_type.elem_type) for inp in onnx_model.graph.input}

# run the model
inputs = {k:Tensor.empty(*shp) for k,shp in input_shapes.items()}
inputs = {k:Tensor.empty(*shp, dtype=getattr(dtypes, input_types[k].name)) for k,shp in input_shapes.items()}
ret: Tensor = next(iter(run_onnx(inputs).values())).cast(dtypes.float32).contiguous()
schedule = ret.lazydata.schedule()

Expand Down

0 comments on commit b5e822a

Please sign in to comment.