[Bug] try to output probability map #2050

Closed
3 tasks done
yuanyuangoo opened this issue May 5, 2023 · 20 comments

Comments

@yuanyuangoo

Checklist

  1. I have searched related issues but cannot get the expected help.
  2. I have read the FAQ documentation but cannot get the expected help.
  3. The bug has not been fixed in the latest version.

Describe the bug

I am trying to convert an mmseg model to a TensorRT engine that outputs a probability map, by putting codebase_config = dict(with_argmax=False) in the config file. With this setting the engine should output the probability map, but at inference time I got this error:
[2023-05-05 14:21:21.583] [mmdeploy] [error] [segment.cpp:45] probability feat map with shape: [1, 8, 1024, 1224] requires with_argmax_=false

Reproduction

import numpy as np

from mmdeploy_runtime import Segmentor

segmentor = Segmentor(
    model_path='./deeplabv3plus_bud/', device_name='cuda', device_id=0)
# use a uint8 HWC image, like the ones returned by cv2.imread
im = np.zeros((1224, 1024, 3), dtype=np.uint8)
seg = segmentor(im)

Environment

05/05 14:29:58 - mmengine - INFO - 

05/05 14:29:58 - mmengine - INFO - **********Environmental information**********
05/05 14:29:59 - mmengine - INFO - sys.platform: linux
05/05 14:29:59 - mmengine - INFO - Python: 3.8.16 (default, Mar  2 2023, 03:21:46) [GCC 11.2.0]
05/05 14:29:59 - mmengine - INFO - CUDA available: True
05/05 14:29:59 - mmengine - INFO - numpy_random_seed: 2147483648
05/05 14:29:59 - mmengine - INFO - GPU 0: Quadro P3200
05/05 14:29:59 - mmengine - INFO - CUDA_HOME: /usr/local/cuda
05/05 14:29:59 - mmengine - INFO - NVCC: Cuda compilation tools, release 11.7, V11.7.99
05/05 14:29:59 - mmengine - INFO - GCC: gcc (Ubuntu 9.4.0-1ubuntu1~20.04.1) 9.4.0
05/05 14:29:59 - mmengine - INFO - PyTorch: 2.0.0
05/05 14:29:59 - mmengine - INFO - PyTorch compiling details: PyTorch built with:
  - GCC 9.3
  - C++ Version: 201703
  - Intel(R) oneAPI Math Kernel Library Version 2021.4-Product Build 20210904 for Intel(R) 64 architecture applications
  - Intel(R) MKL-DNN v2.7.3 (Git Hash 6dbeffbae1f23cbbeae17adb7b5b13f1f37c080e)
  - OpenMP 201511 (a.k.a. OpenMP 4.5)
  - LAPACK is enabled (usually provided by MKL)
  - NNPACK is enabled
  - CPU capability usage: AVX2
  - CUDA Runtime 11.7
  - NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_37,code=compute_37
  - CuDNN 8.5
  - Magma 2.6.1
  - Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.7, CUDNN_VERSION=8.5.0, CXX_COMPILER=/opt/rh/devtoolset-9/root/usr/bin/c++, CXX_FLAGS= -D_GLIBCXX_USE_CXX11_ABI=0 -fabi-version=11 -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -DNDEBUG -DUSE_KINETO -DLIBKINETO_NOROCTRACER -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -O2 -fPIC -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Werror=bool-operation -Wnarrowing -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wunused-local-typedefs -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_DISABLE_GPU_ASSERTS=ON, TORCH_VERSION=2.0.0, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF, 

05/05 14:29:59 - mmengine - INFO - TorchVision: 0.15.1+cu117
05/05 14:29:59 - mmengine - INFO - OpenCV: 4.7.0
05/05 14:29:59 - mmengine - INFO - MMEngine: 0.7.2
05/05 14:29:59 - mmengine - INFO - MMCV: 2.0.0
05/05 14:29:59 - mmengine - INFO - MMCV Compiler: GCC 9.4
05/05 14:29:59 - mmengine - INFO - MMCV CUDA Compiler: 11.7
05/05 14:29:59 - mmengine - INFO - MMDeploy: 1.0.0+840adcf
05/05 14:29:59 - mmengine - INFO - 

05/05 14:29:59 - mmengine - INFO - **********Backend information**********
05/05 14:29:59 - mmengine - INFO - tensorrt:	8.2.4.2
05/05 14:29:59 - mmengine - INFO - tensorrt custom ops:	Available
05/05 14:30:00 - mmengine - INFO - ONNXRuntime:	1.14.1
05/05 14:30:00 - mmengine - INFO - ONNXRuntime-gpu:	1.13.1
05/05 14:30:00 - mmengine - INFO - ONNXRuntime custom ops:	NotAvailable
05/05 14:30:00 - mmengine - INFO - pplnn:	None
05/05 14:30:00 - mmengine - INFO - ncnn:	None
05/05 14:30:00 - mmengine - INFO - snpe:	None
05/05 14:30:00 - mmengine - INFO - openvino:	2022.3.0
05/05 14:30:00 - mmengine - INFO - torchscript:	1.13.1
05/05 14:30:00 - mmengine - INFO - torchscript custom ops:	NotAvailable
05/05 14:30:00 - mmengine - INFO - rknn-toolkit:	None
05/05 14:30:00 - mmengine - INFO - rknn-toolkit2:	None
05/05 14:30:00 - mmengine - INFO - ascend:	None
05/05 14:30:00 - mmengine - INFO - coreml:	None
05/05 14:30:00 - mmengine - INFO - tvm:	None
05/05 14:30:00 - mmengine - INFO - vacc:	None
05/05 14:30:00 - mmengine - INFO - 

05/05 14:30:00 - mmengine - INFO - **********Codebase information**********
05/05 14:30:00 - mmengine - INFO - mmdet:	3.0.0
05/05 14:30:00 - mmengine - INFO - mmseg:	1.0.0
05/05 14:30:00 - mmengine - INFO - mmpretrain:	1.0.0rc7
05/05 14:30:00 - mmengine - INFO - mmocr:	1.0.0
05/05 14:30:00 - mmengine - INFO - mmedit:	None
05/05 14:30:00 - mmengine - INFO - mmdet3d:	None
05/05 14:30:00 - mmengine - INFO - mmpose:	1.0.0
05/05 14:30:00 - mmengine - INFO - mmrotate:	None
05/05 14:30:00 - mmengine - INFO - mmaction:	None
05/05 14:30:00 - mmengine - INFO - mmrazor:	None

Error traceback

[2023-05-05 14:21:21.583] [mmdeploy] [error] [segment.cpp:45] probability feat map with shape: [1, 8, 1024, 1224] requires `with_argmax_=false`
@RunningLeon
Collaborator

@yuanyuangoo hi, the channel number of the preprocessed image is 8, but it should be 3. Could you post your deploy config and pipeline.json here?
For the deploy config, you can run: python -c 'from mmengine import Config;print(Config.fromfile("xxx.py").pretty_text)'

@yuanyuangoo
Author

My pipeline.json:
{
    "pipeline": {
        "input": [
            "img"
        ],
        "output": [
            "post_output"
        ],
        "tasks": [
            {
                "type": "Task",
                "module": "Transform",
                "name": "Preprocess",
                "input": [
                    "img"
                ],
                "output": [
                    "prep_output"
                ],
                "transforms": [
                    {
                        "type": "LoadImageFromFile"
                    },
                    {
                        "type": "Resize",
                        "keep_ratio": false,
                        "size": [
                            1024,
                            1224
                        ]
                    },
                    {
                        "type": "Normalize",
                        "mean": [
                            123.675,
                            116.28,
                            103.53
                        ],
                        "std": [
                            58.395,
                            57.12,
                            57.375
                        ],
                        "to_rgb": true
                    },
                    {
                        "type": "ImageToTensor",
                        "keys": [
                            "img"
                        ]
                    },
                    {
                        "type": "Collect",
                        "keys": [
                            "img"
                        ],
                        "meta_keys": [
                            "img_shape",
                            "pad_shape",
                            "ori_shape",
                            "img_norm_cfg",
                            "scale_factor"
                        ]
                    }
                ]
            },
            {
                "name": "depthwiseseparableaspp",
                "type": "Task",
                "module": "Net",
                "is_batched": false,
                "input": [
                    "prep_output"
                ],
                "output": [
                    "infer_output"
                ],
                "input_map": {
                    "img": "input"
                },
                "output_map": {}
            },
            {
                "type": "Task",
                "module": "mmseg",
                "name": "postprocess",
                "component": "ResizeMask",
                "params": {
                    "type": "DepthwiseSeparableASPPHead",
                    "in_channels": 512,
                    "in_index": 3,
                    "channels": 128,
                    "dilations": [
                        1,
                        12,
                        24,
                        36
                    ],
                    "c1_in_channels": 64,
                    "c1_channels": 12,
                    "dropout_ratio": 0.1,
                    "num_classes": 8,
                    "norm_cfg": {
                        "type": "SyncBN",
                        "requires_grad": true
                    },
                    "align_corners": false,
                    "loss_decode": {
                        "type": "LovaszLoss",
                        "per_image": true,
                        "loss_weight": 1.0,
                        "class_weight": [
                            0.89,
                            0.94,
                            0.9,
                            2.91,
                            3.91,
                            1.09,
                            2.91,
                            0.91
                        ]
                    }
                },
                "output": [
                    "post_output"
                ],
                "input": [
                    "prep_output",
                    "infer_output"
                ]
            }
        ]
    }
}

@yuanyuangoo
Author

deploy.json
{
    "version": "1.0.0",
    "task": "Segmentor",
    "models": [
        {
            "name": "depthwiseseparableaspp",
            "net": "end2end.engine",
            "weights": "",
            "backend": "tensorrt",
            "precision": "FP32",
            "batch_size": 1,
            "dynamic_shape": false
        }
    ],
    "customs": []
}

@RunningLeon
Collaborator

RunningLeon commented May 5, 2023

@yuanyuangoo hi, could you try again with an mmdeploy commit that includes PR #2038?

@yuanyuangoo
Author

Hi RunningLeon, I am on branch main and my branch is up to date with 'origin/main'.

@RunningLeon
Collaborator

RunningLeon commented May 6, 2023

@yuanyuangoo hi, sorry for the trouble. You could add "with_argmax": false to the postprocess "params" in pipeline.json. If you want to reconvert the model, you could change the line here to postprocess['params']['with_argmax'] = with_argmax. If your test is OK, could you kindly create a PR to fix it?

Example pipeline.json (set "with_argmax" to false to get the probability map):

{
    "pipeline": {
        "input": [
            "img"
        ],
        "output": [
            "post_output"
        ],
        "tasks": [
            {
                "type": "Task",
                "module": "Transform",
                "name": "Preprocess",
                "input": [
                    "img"
                ],
                "output": [
                    "prep_output"
                ],
                "transforms": [
                    {
                        "type": "LoadImageFromFile"
                    },
                    {
                        "type": "Resize",
                        "keep_ratio": false,
                        "size": [
                            512,
                            512
                        ]
                    },
                    {
                        "type": "Normalize",
                        "mean": [
                            123.675,
                            116.28,
                            103.53
                        ],
                        "std": [
                            58.395,
                            57.12,
                            57.375
                        ],
                        "to_rgb": true
                    },
                    {
                        "type": "ImageToTensor",
                        "keys": [
                            "img"
                        ]
                    },
                    {
                        "type": "Collect",
                        "keys": [
                            "img"
                        ],
                        "meta_keys": [
                            "img_shape",
                            "pad_shape",
                            "ori_shape",
                            "img_norm_cfg",
                            "scale_factor"
                        ]
                    }
                ]
            },
            {
                "name": "depthwiseseparableaspp",
                "type": "Task",
                "module": "Net",
                "is_batched": false,
                "input": [
                    "prep_output"
                ],
                "output": [
                    "infer_output"
                ],
                "input_map": {
                    "img": "input"
                },
                "output_map": {}
            },
            {
                "type": "Task",
                "module": "mmseg",
                "name": "postprocess",
                "component": "ResizeMask",
                "params": {
                    "type": "DepthwiseSeparableASPPHead",
                    "in_channels": 2048,
                    "in_index": 3,
                    "channels": 512,
                    "dilations": [
                        1,
                        12,
                        24,
                        36
                    ],
                    "c1_in_channels": 256,
                    "c1_channels": 48,
                    "dropout_ratio": 0.1,
                    "num_classes": 19,
                    "norm_cfg": {
                        "type": "SyncBN",
                        "requires_grad": true
                    },
                    "align_corners": false,
                    "loss_decode": {
                        "type": "CrossEntropyLoss",
                        "use_sigmoid": false,
                        "loss_weight": 1.0
                    },
                    "with_argmax": true
                },
                "output": [
                    "post_output"
                ],
                "input": [
                    "prep_output",
                    "infer_output"
                ]
            }
        ]
    }
}

@yuanyuangoo
Author

Thank you @RunningLeon, it works. Would it be possible for the pipeline.json generated by mmdeploy/tools/deploy.py to include "with_argmax": false automatically when I have specified "codebase_config = dict(with_argmax=False)" in my mmdeploy/configs/mmseg/segmentation_tensorrt-fp16_static-1024x1224.py file? Then I wouldn't have to add "with_argmax": false to pipeline.json manually every time I deploy a model.
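In the meantime, a small script can patch the generated pipeline.json after each conversion instead of editing it by hand. This is a minimal sketch, assuming the layout shown in the pipeline.json above (the mmseg post-processing task is the entry in pipeline -> tasks with "module": "mmseg", and the flag goes into its "params"). set_with_argmax is a hypothetical helper, not part of mmdeploy.

```python
import json
from pathlib import Path


def set_with_argmax(pipeline_path, with_argmax=False):
    """Write "with_argmax" into the params of every mmseg task in pipeline.json.

    Hypothetical helper (not part of mmdeploy). It assumes the layout shown in
    this thread: {"pipeline": {"tasks": [..., {"module": "mmseg",
    "params": {...}}]}}. Returns the number of tasks that were patched.
    """
    path = Path(pipeline_path)
    config = json.loads(path.read_text())

    patched = 0
    for task in config["pipeline"]["tasks"]:
        if task.get("module") == "mmseg":
            # create "params" if it is missing, then set the flag
            task.setdefault("params", {})["with_argmax"] = with_argmax
            patched += 1

    if patched == 0:
        raise ValueError(f"no mmseg post-processing task found in {path}")

    # rewrite the file in place (formatting is normalized to a 4-space indent)
    path.write_text(json.dumps(config, indent=4))
    return patched
```

Usage would be something like set_with_argmax("./deeplabv3plus_bud/pipeline.json") right after tools/deploy.py finishes. Raising when no mmseg task is found makes a changed pipeline layout fail loudly instead of silently leaving the flag unset.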

@yuanyuangoo
Author

Another issue: when I output the probability map following your instructions, processing is about 8 times slower than outputting only the segmentation results. Maybe this is because the probability map is larger than the segmentation map, so it takes longer to transfer from the GPU to the host. Is there any way to solve this?
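As a rough check of the size argument, the arithmetic below compares the raw output sizes for the shapes reported in this thread ([1, 8, 1024, 1224] scores versus a single-channel 1024 x 1224 label map). It assumes 4-byte float32 scores and 4-byte int32 labels; the actual dtypes depend on the engine and the SDK. Data volume alone does not prove that the copy is the bottleneck, so profiling is still needed.

```python
# Shapes from this thread: 8 classes at 1024 x 1224.
NUM_CLASSES, HEIGHT, WIDTH = 8, 1024, 1224

# Assumption: 4-byte scores (float32) and 4-byte labels (int32).
FLOAT32_BYTES = 4
INT32_BYTES = 4

prob_bytes = NUM_CLASSES * HEIGHT * WIDTH * FLOAT32_BYTES
label_bytes = HEIGHT * WIDTH * INT32_BYTES

print(f"probability map: {prob_bytes / 2**20:.2f} MiB")
print(f"label map:       {label_bytes / 2**20:.2f} MiB")
print(f"ratio:           {prob_bytes / label_bytes:.1f}x")
```

Under these assumptions, each frame moves 38.25 MiB of scores instead of 4.78 MiB of labels, i.e. 8 times more data, one value per class instead of one per pixel.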

@yuanyuangoo
Author

Additionally, to get the actual confidence of every class, I still need to apply a softmax to the output. Would it be possible to support this in the TensorRT engine?
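Outside the engine, the softmax is a per-pixel normalization over the class axis (dim 1 of the [1, 8, 1024, 1224] output). A minimal NumPy sketch of that host-side step follows; softmax_over_classes is a hypothetical name, and the small random array only stands in for real logits.

```python
import numpy as np


def softmax_over_classes(logits, axis=1):
    """Convert raw logits of shape [N, C, H, W] into per-class probabilities.

    Subtracting the per-pixel maximum before exponentiating avoids overflow
    and does not change the result.
    """
    logits = np.asarray(logits, dtype=np.float32)
    shifted = logits - logits.max(axis=axis, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=axis, keepdims=True)


# Small stand-in for the [1, 8, 1024, 1224] output discussed in this thread.
rng = np.random.default_rng(0)
logits = rng.normal(size=(1, 8, 4, 5)).astype(np.float32)
probs = softmax_over_classes(logits)
```

Because softmax is monotonic, the per-pixel argmax of the probabilities equals the argmax of the raw logits; only the confidence values need the extra normalization.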

@RunningLeon
Collaborator

  1. You need to add codebase_config = dict(with_argmax=False) to your deploy config and update mmdeploy to include PR #2056 (fix with_argmax for mmseg).
  2. How did you test the inference speed?
  3. Do you mean you need two outputs: score and label? This is a customized request; you could do it yourself by adding the two outputs when converting to ONNX.

@yuanyuangoo
Author

The script I use to test speed:


import time
import numpy as np
from mmdeploy_runtime import Segmentor
segmentor = Segmentor(
    model_path='./deeplabv3plus_bud/', device_name='cuda', device_id=0)
# uint8 HWC image, like the ones returned by cv2.imread
im = np.zeros((1224, 1024, 3), dtype=np.uint8)
while True:
    t = time.time()
    seg = segmentor(im)
    print(1 / (time.time() - t))

On my 3080 graphics card with the tensorrt-fp16 model, I get about 7 fps with with_argmax=False and 50 fps with with_argmax=True.
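Timing one call per loop iteration, as in the script above, mixes first-call warm-up into the numbers and prints a noisy per-frame value. A slightly more robust sketch discards some warm-up iterations and averages the rest; benchmark is a hypothetical helper that accepts any callable, such as the segmentor above.

```python
import statistics
import time


def benchmark(fn, *args, warmup=10, iters=100):
    """Time fn(*args) and return (mean latency in seconds, FPS).

    Warm-up calls are run first and discarded so that one-time costs such as
    lazy initialization and first-time memory allocation do not skew the mean.
    fn must block until its work is done; if it only launches asynchronous GPU
    work, synchronize inside fn or the timing measures the launch alone.
    """
    for _ in range(warmup):
        fn(*args)

    latencies = []
    for _ in range(iters):
        start = time.perf_counter()
        fn(*args)
        latencies.append(time.perf_counter() - start)

    mean = statistics.mean(latencies)
    return mean, 1.0 / mean
```

With the script above this would be called as mean_s, fps = benchmark(segmentor, im). time.perf_counter is used instead of time.time because it is a monotonic, high-resolution clock intended for measuring intervals.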

@yuanyuangoo
Author

3. Do you mean you need two outputs: score and label? This is a customized request; you could do it yourself by adding the two outputs when converting to ONNX.

Can you give me a hint on how to achieve that, please?

@yuanyuangoo
Author

@RunningLeon

I tried replacing the following in mmdeploy/codebase/mmseg/models/segmentors/encoder_decoder.py:30

    if get_codebase_config(ctx.cfg).get('with_argmax', True) is False:
        return seg_logit

to

    if get_codebase_config(ctx.cfg).get('with_argmax', True) is False:
        return seg_logit.softmax(dim=1)

and now I get correct confidence values for all classes.

But processing is still about 10 times slower than outputting only the class maps. Is there any way to solve this?

@RunningLeon
Collaborator

@yuanyuangoo, hi, you are right about outputting scores. Because you added a softmax layer, the FPS is lower. As for why it's 10 times slower, you will have to profile it yourself and determine each layer's share of the total time.
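A coarse first step toward that breakdown is to time each stage of the pipeline (for example preprocessing, inference, postprocessing and the GPU-to-host copy) separately. This is a minimal sketch; StageTimer and the stage names are hypothetical. It measures wall-clock time only, so asynchronous GPU work must be synchronized inside each stage, and true per-layer numbers still need a TensorRT profiler (for example, trtexec with per-layer profiling).

```python
import time
from collections import defaultdict
from contextlib import contextmanager


class StageTimer:
    """Accumulate wall-clock time per named stage (hypothetical helper)."""

    def __init__(self):
        self.totals = defaultdict(float)

    @contextmanager
    def stage(self, name):
        # time the body of the `with` block and add it to this stage's total
        start = time.perf_counter()
        try:
            yield
        finally:
            self.totals[name] += time.perf_counter() - start

    def report(self):
        """Return each stage's share of the total measured time (sums to 1)."""
        grand_total = sum(self.totals.values())
        return {name: t / grand_total for name, t in self.totals.items()}
```

Typical use wraps each step of one inference call, e.g. `with timer.stage("preprocess"): ...`, `with timer.stage("inference"): ...`, over many frames, then prints timer.report() to see which stage dominates.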

@yuanyuangoo
Author

@RunningLeon I just tried using the TensorRT Python API to load the engine file and run inference, and I get more than 20 fps, including all the preprocessing and postprocessing. But when I use the mmdeploy Python runtime, I only get about 3-5 fps. From this comparison, it's clear that mmdeploy has the potential to be much faster.

@RunningLeon
Collaborator

RunningLeon commented May 11, 2023

@yuanyuangoo, comparison between two backends needs careful design. How did you test with the TensorRT Python API? Please make sure the pre- and post-processing are included. BTW, from personal experience, the ArgMax layer is extremely slow for segmentation models, even in TensorRT.

@yuanyuangoo
Author

import time

import cupy as cp
import cv2
import tensorrt as trt
from cupyx.scipy import ndimage

PALETTE = [[0, 0, 0], [61, 61, 245], [64, 128, 32], [255, 0, 41],
           [250, 85, 246], [241, 177, 195], [255, 204, 51], [255, 255, 255]]


class TRTWrapper:
    def __init__(self, engine_file):
        self.runtime = trt.Runtime(trt.Logger(trt.Logger.WARNING))
        self.engine = self.load_engine(engine_file)
        self.context = self.engine.create_execution_context()
        self.stream = cp.cuda.Stream(non_blocking=False)

    def load_engine(self, engine_file):
        with open(engine_file, "rb") as f:
            return self.runtime.deserialize_cuda_engine(f.read())

    def infer_with_cupy(self, original_img):
        # preprocess image (resize/color conversion on CPU, normalization on GPU)
        input_image = preprocess_with_cupy(original_img).astype('float32')
        input_image = cp.ascontiguousarray(input_image)

        # allocate GPU memory for the outputs; bindings must follow the
        # engine's binding order (this assumes the input is binding 0)
        bindings = [int(input_image.data.ptr)]
        label_mem = conf_mem = None
        for binding in self.engine:
            if self.engine.binding_is_input(binding):
                continue
            binding_idx = self.engine.get_binding_index(binding)
            size = trt.volume(self.context.get_binding_shape(binding_idx))
            dtype = trt.nptype(self.engine.get_binding_dtype(binding))
            mem = cp.zeros(size, dtype=dtype)
            # the binding named 'output' holds the confidence, the other the label
            if binding == 'output':
                conf_mem = mem
            else:
                label_mem = mem
            bindings.append(int(mem.data.ptr))

        # inference
        self.context.execute_async_v2(bindings=bindings,
                                      stream_handle=self.stream.ptr)
        self.stream.synchronize()

        # postprocess
        label, conf = postprocess_with_softmax_cupy(conf_mem, label_mem)

        # visualize and save image
        # img = add_color_images_cupy(original_img, label, PALETTE)
        # img = img * conf[:, :, cp.newaxis]
        # cv2.imwrite(filename='1_output.jpg', img=cp.asnumpy(img))
        return cp.asnumpy(label), cp.asnumpy(conf)


def preprocess_with_cupy(image):
    # resize to the engine input size; cv2.resize takes (width, height)
    image = cv2.resize(image, (1224, 1024))
    # change BGR to RGB
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # mean/std normalization on the GPU
    mean = cp.array([123.675, 116.28, 103.53], dtype=cp.float32)
    std = cp.array([58.395, 57.12, 57.375], dtype=cp.float32)
    data = (cp.asarray(image).astype(cp.float32) - mean) / std

    # switch from HWC to CHW order
    data = cp.moveaxis(data, 2, 0)

    # add batch dimension
    data = cp.expand_dims(data, axis=0)
    return data


def add_color_images_cupy(img, seg, palette):
    color_seg = cp.zeros((seg.shape[0], seg.shape[1], 3), dtype=cp.uint8)
    for label, color in enumerate(palette):
        color_seg[seg == label, :] = color
    # convert to BGR
    color_seg = color_seg[..., ::-1]

    img = cp.asarray(img) * 0.5 + color_seg * 0.5
    return img


def postprocess_with_softmax_cupy(conf, label):
    # reshape the flat output buffers to 2D (height 1024, width 1224)
    conf = conf.reshape(1024, 1224)
    label = label.reshape(1024, 1224)

    # resize to the original size (2x: 1024x1224 -> 2048x2448)
    conf = ndimage.zoom(conf, zoom=2.0, output=cp.float16)
    label = ndimage.zoom(label, zoom=2.0, order=0, output=cp.int8)
    return label, conf


def main():
    # load engine
    trt_wrapper = TRTWrapper('./deeplabv3plus_bud/end2end.engine')

    # read image; the image size should be 2448x2048
    input_image = cv2.imread('./1.png')

    # inference, returns label and confidence maps
    while True:
        t1 = time.time()
        label, conf = trt_wrapper.infer_with_cupy(input_image)
        elapsed = time.time() - t1
        print(elapsed, 1 / elapsed)


if __name__ == '__main__':
    main()

@yuanyuangoo
Author

The above is the inference script.

I modified the code in /root/workspace/mmdeploy/mmdeploy/codebase/mmseg/models/segmentors/encoder_decoder.py, replacing "seg_logit" on line 31 with "seg_logit.softmax(dim=1).max(dim=1, keepdim=True)".
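For reference, in PyTorch that expression returns a (values, indices) pair, so the exported graph gets two outputs: the per-pixel top probability and its class index, each of shape [N, 1, H, W]. The NumPy sketch below reproduces the same computation on the host for checking results; confidence_and_label is a hypothetical name and not part of mmdeploy.

```python
import numpy as np


def confidence_and_label(logits):
    """NumPy analogue of seg_logit.softmax(dim=1).max(dim=1, keepdim=True).

    Returns (confidence, label), each of shape [N, 1, H, W]: the highest
    class probability per pixel and the index of that class.
    """
    logits = np.asarray(logits, dtype=np.float32)
    # numerically stable softmax over the class axis
    shifted = logits - logits.max(axis=1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=1, keepdims=True)
    # top-1 probability and its class index, keeping the class axis
    confidence = probs.max(axis=1, keepdims=True)
    label = probs.argmax(axis=1)[:, np.newaxis]
    return confidence, label
```

Reducing to one channel per output inside the engine means the host copies two single-channel maps instead of the full 8-channel probability map.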

I also added "codebase_config = dict(with_argmax=False)" to the end of "/root/workspace/mmdeploy/configs/mmseg/segmentation_tensorrt-fp16_static-1024x1224.py"

After this, I converted the PyTorch model to a TensorRT engine, ran inference, and got 30+ fps.

@github-actions

This issue is marked as stale because it has been marked as invalid or awaiting response for 7 days without any further response. It will be closed in 5 days if the stale label is not removed or if there is no further response.

@github-actions github-actions bot added the Stale label May 19, 2023
@github-actions

This issue is closed because it has been stale for 5 days. Please open a new issue if you have similar issues or you have any new updates now.

@github-actions github-actions bot closed this as not planned May 24, 2023