Skip to content

Commit

Permalink
onnxmodel fp16_to_fp32: misc improvements (commaai#33615)
Browse files Browse the repository at this point in the history
  • Loading branch information
ZwX1616 authored Sep 20, 2024
1 parent 8d50970 commit b297663
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions selfdrive/modeld/runners/onnxmodel.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,21 @@ def attributeproto_fp16_to_fp32(attr):
attr.data_type = 1
attr.raw_data = float32_list.astype(np.float32).tobytes()

def convert_fp16_to_fp32(path):
model = onnx.load(path)
def convert_fp16_to_fp32(onnx_path_or_bytes):
if isinstance(onnx_path_or_bytes, bytes):
model = onnx.load_from_string(onnx_path_or_bytes)
elif isinstance(onnx_path_or_bytes, str):
model = onnx.load(onnx_path_or_bytes)

for i in model.graph.initializer:
if i.data_type == 10:
attributeproto_fp16_to_fp32(i)
for i in itertools.chain(model.graph.input, model.graph.output):
if i.type.tensor_type.elem_type == 10:
i.type.tensor_type.elem_type = 1
for i in model.graph.node:
if i.op_type == 'Cast' and i.attribute[0].i == 10:
i.attribute[0].i = 1
for a in i.attribute:
if hasattr(a, 't'):
if a.t.data_type == 10:
Expand Down

0 comments on commit b297663

Please sign in to comment.