Skip to content

Commit

Permalink
runs but
Browse files Browse the repository at this point in the history
  • Loading branch information
Comma Device committed Aug 28, 2024
1 parent 34232ad commit 3db37c0
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions selfdrive/modeld/dmonitoringmodeld.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from openpilot.common.params import Params
from openpilot.common.realtime import set_realtime_priority
from openpilot.selfdrive.modeld.runners import ModelRunner, Runtime
from openpilot.selfdrive.modeld.models.commonmodel_pyx import sigmoid
from openpilot.selfdrive.modeld.models.commonmodel_pyx import sigmoid, CLContext

CALIB_LEN = 3
REG_SCALE = 0.25
Expand Down Expand Up @@ -56,14 +56,14 @@ class ModelState:
output: np.ndarray
model: ModelRunner

def __init__(self):
def __init__(self, cl_ctx):
assert ctypes.sizeof(DMonitoringModelResult) == OUTPUT_SIZE * ctypes.sizeof(ctypes.c_float)
self.output = np.zeros(OUTPUT_SIZE, dtype=np.float32)
self.inputs = {
'input_img': np.zeros(MODEL_HEIGHT * MODEL_WIDTH, dtype=np.uint8),
'calib': np.zeros(CALIB_LEN, dtype=np.float32)}

self.model = ModelRunner(MODEL_PATHS, self.output, Runtime.GPU, True, None)
self.model = ModelRunner(MODEL_PATHS, self.output, Runtime.GPU, False, cl_ctx)
self.model.addInput("input_img", None)
self.model.addInput("calib", self.inputs['calib'])

Expand All @@ -77,7 +77,7 @@ def run(self, buf:VisionBuf, calib:np.ndarray) -> tuple[np.ndarray, float]:
input_data[:] = buf_data[v_offset:v_offset+MODEL_HEIGHT, h_offset:h_offset+MODEL_WIDTH]

t1 = time.perf_counter()
self.model.setInputBuffer("input_img", self.inputs['input_img'].view(np.float32))
self.model.setInputBuffer("input_img", self.inputs['input_img'].astype(np.float32))
self.model.execute()
t2 = time.perf_counter()
return self.output, t2 - t1
Expand Down Expand Up @@ -117,12 +117,13 @@ def main():
gc.disable()
set_realtime_priority(1)

model = ModelState()
cl_context = CLContext()
model = ModelState(cl_context)
cloudlog.warning("models loaded, dmonitoringmodeld starting")
Params().put_bool("DmModelInitialized", True)

cloudlog.warning("connecting to driver stream")
vipc_client = VisionIpcClient("camerad", VisionStreamType.VISION_STREAM_DRIVER, True)
vipc_client = VisionIpcClient("camerad", VisionStreamType.VISION_STREAM_DRIVER, True, cl_context)
while not vipc_client.connect(False):
time.sleep(0.1)
assert vipc_client.is_connected()
Expand Down

0 comments on commit 3db37c0

Please sign in to comment.