Skip to content

Commit

Permalink
modified device selection (#21)
Browse files Browse the repository at this point in the history
* modified device selection

device cannot sucessfully control through argments "device"

* update with_sync

* Update ORTWrapper

change the way to create ort session, previous work would load same model twice.

* Update wrapper.py

fixed for lint

* Update wrapper.py

* Update wrapper.py

remove the backslash

* formating

using yapf to format the file

Co-authored-by: AllentDan <AllentDan@yeah.net>
  • Loading branch information
Stephenfang51 and AllentDan authored Dec 31, 2021
1 parent 48bea16 commit f203306
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 13 deletions.
15 changes: 4 additions & 11 deletions mmdeploy/backend/onnxruntime/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,26 +50,19 @@ def __init__(self,
else:
logging.warning(f'The library of onnxruntime custom ops does \
not exist: {ort_custom_op_path}')

sess = ort.InferenceSession(onnx_file, session_options)

device_id = parse_device_id(device)

providers = ['CPUExecutionProvider']
options = [{}]
is_cuda_available = ort.get_device() == 'GPU'
if is_cuda_available:
providers.insert(0, 'CUDAExecutionProvider')
options.insert(0, {'device_id': device_id})
sess.set_providers(providers, options)
providers = [('CUDAExecutionProvider', {'device_id': device_id})] \
if is_cuda_available else ['CPUExecutionProvider']
sess = ort.InferenceSession(
onnx_file, session_options, providers=providers)
if output_names is None:
output_names = [_.name for _ in sess.get_outputs()]
self.sess = sess
self.io_binding = sess.io_binding()
self.device_id = device_id
self.is_cuda_available = is_cuda_available
self.device_type = 'cuda' if is_cuda_available else 'cpu'

super().__init__(output_names)

def forward(self, inputs: Dict[str,
Expand Down
4 changes: 2 additions & 2 deletions tools/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,15 +106,15 @@ def main():

device_id = parse_device_id(args.device)

model = MMDataParallel(model, device_ids=[0])
model = MMDataParallel(model, device_ids=[device_id])
# The whole dataset test wrapped a MMDataParallel class outside the module.
# As mmcls.apis.test.py single_gpu_test defined, the MMDataParallel needs
# a 'CLASSES' attribute. So we ensure the MMDataParallel class has the same
# CLASSES attribute as the inside module.
if hasattr(model.module, 'CLASSES'):
model.CLASSES = model.module.CLASSES
if args.speed_test:
with_sync = device_id == 0
with_sync = device_id >= 0
output_file = sys.stdout
if args.log2file:
output_file = args.log2file
Expand Down

0 comments on commit f203306

Please sign in to comment.