This repository has been archived by the owner on Dec 16, 2023. It is now read-only.

Commit

fixed onnx
jc01rho committed Sep 6, 2022
1 parent b91daf4 commit 17dc1f8
Showing 3 changed files with 26 additions and 13 deletions.
35 changes: 24 additions & 11 deletions selfdrive/modeld/runners/onnx_runner.py
@@ -9,36 +9,46 @@
 
 import onnxruntime as ort  # pylint: disable=import-error
 
-def read(sz):
+def read(sz, tf8=False):
   dd = []
   gt = 0
-  while gt < sz * 4:
-    st = os.read(0, sz * 4 - gt)
+  szof = 1 if tf8 else 4
+  while gt < sz * szof:
+    st = os.read(0, sz * szof - gt)
     assert(len(st) > 0)
     dd.append(st)
     gt += len(st)
-  return np.frombuffer(b''.join(dd), dtype=np.float32)
+  r = np.frombuffer(b''.join(dd), dtype=np.uint8 if tf8 else np.float32).astype(np.float32)
+  if tf8:
+    r = r / 255.
+  return r
 
 def write(d):
   os.write(1, d.tobytes())
 
-def run_loop(m):
+def run_loop(m, tf8_input=False):
   ishapes = [[1]+ii.shape[1:] for ii in m.get_inputs()]
   keys = [x.name for x in m.get_inputs()]
+
+  # run once to initialize CUDA provider
+  if "CUDAExecutionProvider" in m.get_providers():
+    m.run(None, dict(zip(keys, [np.zeros(shp, dtype=np.float32) for shp in ishapes])))
+
   print("ready to run onnx model", keys, ishapes, file=sys.stderr)
   while 1:
     inputs = []
-    for shp in ishapes:
+    for k, shp in zip(keys, ishapes):
       ts = np.product(shp)
       #print("reshaping %s with offset %d" % (str(shp), offset), file=sys.stderr)
-      inputs.append(read(ts).reshape(shp))
+      inputs.append(read(ts, (k=='input_img' and tf8_input)).reshape(shp))
     ret = m.run(None, dict(zip(keys, inputs)))
     #print(ret, file=sys.stderr)
     for r in ret:
       write(r)
 
 
 if __name__ == "__main__":
   print(sys.argv, file=sys.stderr)
+  print("Onnx available providers: ", ort.get_available_providers(), file=sys.stderr)
   options = ort.SessionOptions()
   options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
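
The heart of this hunk is the tf8 input path: with tf8=True, read() pulls one byte per element off stdin instead of four, then rescales to float32 in [0, 1]. A minimal sketch of that decode step in isolation (the three sample values are illustrative, not from the commit):

import numpy as np

raw = bytes([0, 128, 255])  # three tf8 elements arrive as three bytes, not twelve
x = np.frombuffer(raw, dtype=np.uint8).astype(np.float32) / 255.
print(x)  # -> [0.0, ~0.502, 1.0]
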
@@ -54,7 +64,10 @@ def run_loop(m):
     options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
     provider = 'CPUExecutionProvider'
 
-  print("Onnx selected provider: ", [provider], file=sys.stderr)
-  ort_session = ort.InferenceSession(sys.argv[1], options, providers=[provider])
-  print("Onnx using ", ort_session.get_providers(), file=sys.stderr)
-  run_loop(ort_session)
+  try:
+    print("Onnx selected provider: ", [provider], file=sys.stderr)
+    ort_session = ort.InferenceSession(sys.argv[1], options, providers=[provider])
+    print("Onnx using ", ort_session.get_providers(), file=sys.stderr)
+    run_loop(ort_session, tf8_input=("--use_tf8" in sys.argv))
+  except KeyboardInterrupt:
+    pass
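
For context, the runner speaks a raw pipe protocol: inputs arrive on stdin in model-input order, and each output goes back on stdout as float32 bytes. A hypothetical driver for the new flag might look like the sketch below; the model path, input name ('input_img'), and shape are assumptions for illustration, not taken from this commit (in openpilot the real driver is the C++ ONNXModel):

import subprocess
import numpy as np

# spawn the runner with tf8 input enabled (paths/shapes assumed for the example)
proc = subprocess.Popen(["python3", "onnx_runner.py", "model.onnx", "--use_tf8"],
                        stdin=subprocess.PIPE, stdout=subprocess.PIPE)

img = np.zeros((1, 12, 128, 256), dtype=np.uint8)  # the 'input_img' tensor, one byte per element
proc.stdin.write(img.tobytes())  # read() on the other end rescales by 1/255
proc.stdin.flush()
# any remaining model inputs would be sent as float32, and each output read
# back from proc.stdout as (element count) * 4 bytes
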
2 changes: 1 addition & 1 deletion selfdrive/modeld/runners/onnxmodel.cc
@@ -14,7 +14,7 @@
 #include "common/swaglog.h"
 #include "common/util.h"
 
-ONNXModel::ONNXModel(const char *path, float *_output, size_t _output_size, int runtime, bool _use_extra, bool _use_tf8) {
+ONNXModel::ONNXModel(const char *path, float *_output, size_t _output_size, int runtime, bool _use_extra, bool _use_tf8, cl_context context) {
   LOGD("loading model %s", path);
 
   output = _output;
2 changes: 1 addition & 1 deletion selfdrive/modeld/runners/onnxmodel.h
@@ -6,7 +6,7 @@
 
 class ONNXModel : public RunModel {
 public:
-  ONNXModel(const char *path, float *output, size_t output_size, int runtime, bool use_extra = false, bool _use_tf8 = false);
+  ONNXModel(const char *path, float *output, size_t output_size, int runtime, bool use_extra = false, bool _use_tf8 = false, cl_context context = NULL);
   ~ONNXModel();
   void addRecurrent(float *state, int state_size);
   void addDesire(float *state, int state_size);
