Unexpected result on multi-batch gather. #2299

Closed
grimoire opened this issue Sep 6, 2022 · 4 comments

@grimoire

grimoire commented Sep 6, 2022

Description

Gathering with top-k indices on a multi-batch tensor gives unexpected results.
Note that if we replace the optimization profile with the following, where only the batch dimension is dynamic:

    C=10
    input_shapes = {
        'input': {
            'min_shape': [1, C, 4],
            'opt_shape': [2, C, 4],
            'max_shape': [4, C, 4]
        }
    }

we get the correct result.

Please read the code below for more detail.

Environment

TensorRT Version: 8.4.1.5
NVIDIA GPU: 2060s
NVIDIA Driver Version: 510.85.02
CUDA Version: 11.3
CUDNN Version: 8.2.1
Operating System: Ubuntu 18.04
Python Version (if applicable): 3.7
Tensorflow Version (if applicable):
PyTorch Version (if applicable): 1.10.0
Baremetal or Container (if so, version):

Relevant Files

Steps To Reproduce

import torch
import tensorrt as trt
import onnx
from typing import Dict


def from_onnx(onnx_model, input_shapes, max_workspace_size):
    logger = trt.Logger(trt.Logger.INFO)
    builder = trt.Builder(logger)
    EXPLICIT_BATCH = 1 << int(
        trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
    network = builder.create_network(EXPLICIT_BATCH)

    # parse onnx
    parser = trt.OnnxParser(network, logger)

    if isinstance(onnx_model, str):
        onnx_model = onnx.load(onnx_model)

    if not parser.parse(onnx_model.SerializeToString()):
        error_msgs = ''
        for error in range(parser.num_errors):
            error_msgs += f'{parser.get_error(error)}\n'
        raise RuntimeError(f'Failed to parse onnx, {error_msgs}')

    config = builder.create_builder_config()
    config.max_workspace_size = max_workspace_size

    profile = builder.create_optimization_profile()

    for input_name, param in input_shapes.items():
        min_shape = param['min_shape']
        opt_shape = param['opt_shape']
        max_shape = param['max_shape']
        profile.set_shape(input_name, min_shape, opt_shape, max_shape)
    config.add_optimization_profile(profile)

    engine = builder.build_engine(network, config)

    return engine


TORCH_DTYPE_MAP = {
    trt.bool: torch.bool,
    trt.int8: torch.int8,
    trt.int32: torch.int32,
    trt.float16: torch.float16,
    trt.float32: torch.float32
}


class TRTWrapper(torch.nn.Module):

    def __init__(self, engine: trt.ICudaEngine):
        super().__init__()
        self.engine = engine

        if not isinstance(self.engine, trt.ICudaEngine):
            raise TypeError('`engine` should be trt.ICudaEngine, '
                            f'but given: {type(self.engine)}')

        self.context = self.engine.create_execution_context()
        self.__load_io_names()

    def __load_io_names(self):
        """Load input/output names from engine."""
        names = [_ for _ in self.engine]
        input_names = list(filter(self.engine.binding_is_input, names))
        self._input_names = input_names

        output_names = list(set(names) - set(input_names))
        self._output_names = output_names

    def forward(self, inputs: Dict[str,
                                   torch.Tensor]) -> Dict[str, torch.Tensor]:
        """Run forward inference.

        Args:
            inputs (Dict[str, torch.Tensor]): The input name and tensor pairs.

        Return:
            Dict[str, torch.Tensor]: The output name and tensor pairs.
        """
        bindings = [None] * (len(self._input_names) + len(self._output_names))

        for input_name, input_tensor in inputs.items():
            idx = self.engine.get_binding_index(input_name)

            # All input tensors must be gpu variables
            input_tensor = input_tensor.contiguous()
            if input_tensor.dtype == torch.long:
                input_tensor = input_tensor.int()
            self.context.set_binding_shape(idx, tuple(input_tensor.shape))
            bindings[idx] = input_tensor.contiguous().data_ptr()

        # create output tensors
        outputs = {}
        for output_name in self._output_names:
            idx = self.engine.get_binding_index(output_name)
            dtype = TORCH_DTYPE_MAP[self.engine.get_binding_dtype(idx)]
            shape = tuple(self.context.get_binding_shape(idx))

            output = torch.empty(size=shape, dtype=dtype, device='cuda')
            outputs[output_name] = output
            bindings[idx] = output.data_ptr()

        self.context.execute_async_v2(bindings,
                                      torch.cuda.current_stream().cuda_stream)

        return outputs


class TestModel(torch.nn.Module):

    def __init__(self) -> None:
        super().__init__()

    def forward(self, x):
        batch_size = x.size(0)
        C = x.size(1)
        max_x, _ = x.max(-1)
        _, inds = max_x.topk(4)
        batch_inds = torch.arange(batch_size, device=inds.device).unsqueeze(-1)

        # new_x = torch.gather(x, 1, inds.unsqueeze(-1).expand(batch_size, 4, 4))
        new_x = x[batch_inds, inds, ...]
        # new_x = x.flatten(0, 1)[inds + batch_inds * C]
        return new_x, inds + batch_inds * C


def main():
    # models
    model = TestModel().cuda()
    x = torch.rand(1, 10, 4).cuda()

    # export onnx
    input_names = ['input']
    output_names = ['output', 'inds']
    torch.onnx.export(
        model,
        x,
        'tmp.onnx',
        input_names=input_names,
        output_names=output_names,
        dynamic_axes={'input': {
            0: 'b',
            1: 'n'
        }},
        opset_version=11)

    # export tensorrt
    input_shapes = {
        'input': {
            'min_shape': [1, 5, 4],
            'opt_shape': [2, 10, 4],
            'max_shape': [4, 40, 4]
        }
    }
    engine = from_onnx(
        'tmp.onnx', input_shapes=input_shapes, max_workspace_size=1 << 30)

    wrapper = TRTWrapper(engine)

    x = torch.rand(2, 10, 4).cuda()

    torch_out = model(x)
    out = wrapper({'input': x})
    out = [out[name] for name in output_names]

    # print(x)

    for o, to in zip(out, torch_out):
        print(o.shape)
        torch.testing.assert_allclose(o, to)

    # print(torch_out)


if __name__ == '__main__':
    main()
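
For readers without a GPU at hand, here is a minimal pure-Python sketch of what `TestModel.forward` is expected to return, which is the result the TensorRT engine should match. The helper name `topk_gather_reference` and the nested-list representation are assumptions made for illustration; they are not part of the reproduction script. Tie-breaking between equal maxima is also an assumption here (first index wins).

```python
from typing import List, Tuple

Tensor3D = List[List[List[float]]]  # nested lists with shape [batch, C, D]


def topk_gather_reference(x: Tensor3D, k: int = 4) -> Tuple[Tensor3D, List[List[int]]]:
    """Hypothetical reference for TestModel.forward using plain lists."""
    c = len(x[0])
    new_x, flat_inds = [], []
    for b, rows in enumerate(x):
        # Max over the last axis, like x.max(-1).
        row_max = [max(row) for row in rows]
        # Indices of the k largest maxima in descending order, like topk(k).
        inds = sorted(range(c), key=lambda i: row_max[i], reverse=True)[:k]
        # x[batch_inds, inds, ...]: rows are taken from the same batch element.
        new_x.append([rows[i] for i in inds])
        # inds + batch_inds * C: the same rows addressed in x.flatten(0, 1).
        flat_inds.append([i + b * c for i in inds])
    return new_x, flat_inds


x = [
    [[0.1, 0.2], [0.9, 0.0], [0.5, 0.4]],
    [[0.3, 0.3], [0.2, 0.8], [0.7, 0.1]],
]
new_x, flat_inds = topk_gather_reference(x, k=2)
print(new_x)
print(flat_inds)
```

Each gathered row must come from its own batch element, so any output row that does not appear in the corresponding batch of the input points to the gather going wrong rather than the top-k.
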
@zerollzeng
Collaborator

Tried to reproduce the issue with TRT 8.4.1.5 using polygraphy:

[I] onnxrt-runner-N0-09/07/22-08:16:25  | Completed 1 iteration(s) in 0.1693 ms | Average inference time: 0.1693 ms.
[I] Accuracy Comparison | trt-runner-N0-09/07/22-08:16:25 vs. onnxrt-runner-N0-09/07/22-08:16:25
[I]     Comparing Output: 'output' (dtype=float32, shape=(2, 4, 4)) with 'output' (dtype=float32, shape=(2, 4, 4))
[I]     Tolerance: [abs=1e-05, rel=1e-05] | Checking elemwise error
[I]         trt-runner-N0-09/07/22-08:16:25: output | Stats: mean=0.58187, std-dev=0.31482, var=0.099113, median=0.6921, min=0.039055 at (0, 2, 3), max=0.98886 at (1, 0, 0), avg-magnitude=0.58187
[I]         onnxrt-runner-N0-09/07/22-08:16:25: output | Stats: mean=0.58187, std-dev=0.31482, var=0.099113, median=0.6921, min=0.039055 at (0, 2, 3), max=0.98886 at (1, 0, 0), avg-magnitude=0.58187
[I]         Error Metrics: output
[I]             Minimum Required Tolerance: elemwise error | [abs=0] OR [rel=0] (requirements may be lower if both abs/rel tolerances are set)
[I]             Absolute Difference | Stats: mean=0, std-dev=0, var=0, median=0, min=0 at (0, 0, 0), max=0 at (0, 0, 0), avg-magnitude=0
[I]             Relative Difference | Stats: mean=0, std-dev=0, var=0, median=0, min=0 at (0, 0, 0), max=0 at (0, 0, 0), avg-magnitude=0
[I]         PASSED | Difference is within tolerance (rel=1e-05, abs=1e-05)
[I]     Comparing Output: 'inds' (dtype=int32, shape=(2, 4)) with 'inds' (dtype=int64, shape=(2, 4))
[I]     Tolerance: [abs=1e-05, rel=1e-05] | Checking elemwise error
[I]         trt-runner-N0-09/07/22-08:16:25: inds | Stats: mean=10.375, std-dev=4.7942, var=22.984, median=9, min=5 at (0, 0), max=19 at (1, 2), avg-magnitude=10.375
[I]         onnxrt-runner-N0-09/07/22-08:16:25: inds | Stats: mean=10.375, std-dev=4.7942, var=22.984, median=9, min=5 at (0, 0), max=19 at (1, 2), avg-magnitude=10.375
[I]         Error Metrics: inds
[I]             Minimum Required Tolerance: elemwise error | [abs=0] OR [rel=0] (requirements may be lower if both abs/rel tolerances are set)
[I]             Absolute Difference | Stats: mean=0, std-dev=0, var=0, median=0, min=0 at (0, 0), max=0 at (0, 0), avg-magnitude=0
[I]             Relative Difference | Stats: mean=0, std-dev=0, var=0, median=0, min=0 at (0, 0), max=0 at (0, 0), avg-magnitude=0
[I]         PASSED | Difference is within tolerance (rel=1e-05, abs=1e-05)
[I]     PASSED | All outputs matched | Outputs: ['output', 'inds']
[I] PASSED | Command: /home/zeroz/.local/bin/polygraphy run tmp.onnx --trt --onnxrt --trt-opt-shapes input:[2,10,4] --input-shapes input:[2,10,4]

The outputs match between TRT and ONNX Runtime. Can you check whether they also match between Torch and ONNX?

@zerollzeng zerollzeng self-assigned this Sep 7, 2022
@zerollzeng zerollzeng added the triaged Issue has been triaged by maintainers label Sep 7, 2022
@grimoire
Author

grimoire commented Sep 8, 2022

The results matched between Torch and ONNX.
Note that this error only appears when min_shape[0] != max_shape[0] and min_shape[1] != max_shape[1].
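
The condition above can be written as a small check over the `input_shapes` entries used in the script. This is only a sketch: the function name `varies_in_batch_and_second_dim` is hypothetical, and it encodes the observation from this thread, not a documented TensorRT rule.

```python
from typing import Dict, List


def varies_in_batch_and_second_dim(profile: Dict[str, List[int]]) -> bool:
    """True when both dim 0 and dim 1 are dynamic in the profile."""
    min_shape, max_shape = profile['min_shape'], profile['max_shape']
    return min_shape[0] != max_shape[0] and min_shape[1] != max_shape[1]


# Profile from the reproduction script (gives the wrong result).
failing = {'min_shape': [1, 5, 4], 'opt_shape': [2, 10, 4], 'max_shape': [4, 40, 4]}
# Profile from the description with C=10 (gives the right result).
working = {'min_shape': [1, 10, 4], 'opt_shape': [2, 10, 4], 'max_shape': [4, 10, 4]}

print(varies_in_batch_and_second_dim(failing))
print(varies_in_batch_and_second_dim(working))
```
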

@zerollzeng
Collaborator

I can reproduce this with the full dynamic-shape profile (min, opt, and max shapes all set):

[I]         trt-runner-N0-09/09/22-00:16:40: output | Stats: mean=0.35972, std-dev=0.34652, var=0.12008, median=0.27958, min=0 at (1, 0, 0), max=0.96826 at (0, 0, 1), avg-magnitude=0.35972
[I]             ---- Histogram ----
                Bin Range        |  Num Elems | Visualization
                (0     , 0.0989) |         12 | ########################################
                (0.0989, 0.198 ) |          2 | ######
                (0.198 , 0.297 ) |          3 | ##########
                (0.297 , 0.396 ) |          2 | ######
                (0.396 , 0.494 ) |          3 | ##########
                (0.494 , 0.593 ) |          1 | ###
                (0.593 , 0.692 ) |          1 | ###
                (0.692 , 0.791 ) |          1 | ###
                (0.791 , 0.89  ) |          3 | ##########
                (0.89  , 0.989 ) |          4 | #############
[I]         onnxrt-runner-N0-09/09/22-00:16:40: output | Stats: mean=0.58187, std-dev=0.31482, var=0.099113, median=0.6921, min=0.039055 at (0, 2, 3), max=0.98886 at (1, 0, 0), avg-magnitude=0.58187
[I]             ---- Histogram ----
                Bin Range        |  Num Elems | Visualization
                (0     , 0.0989) |          3 | ###############
                (0.0989, 0.198 ) |          3 | ###############
                (0.198 , 0.297 ) |          2 | ##########
                (0.297 , 0.396 ) |          3 | ###############
                (0.396 , 0.494 ) |          2 | ##########
                (0.494 , 0.593 ) |          2 | ##########
                (0.593 , 0.692 ) |          1 | #####
                (0.692 , 0.791 ) |          5 | #########################
                (0.791 , 0.89  ) |          3 | ###############
                (0.89  , 0.989 ) |          8 | ########################################
[I]         Error Metrics: output
[I]             Minimum Required Tolerance: elemwise error | [abs=0.98886] OR [rel=1.1358] (requirements may be lower if both abs/rel tolerances are set)
[I]             Absolute Difference | Stats: mean=0.23223, std-dev=0.32724, var=0.10709, median=0.0025968, min=0 at (0, 0, 0), max=0.98886 at (1, 0, 0), avg-magnitude=0.23223
[I]                 ---- Histogram ----
                    Bin Range        |  Num Elems | Visualization
                    (0     , 0.0989) |         18 | ########################################
                    (0.0989, 0.198 ) |          3 | ######
                    (0.198 , 0.297 ) |          3 | ######
                    (0.297 , 0.396 ) |          0 |
                    (0.396 , 0.494 ) |          1 | ##
                    (0.494 , 0.593 ) |          0 |
                    (0.593 , 0.692 ) |          1 | ##
                    (0.692 , 0.791 ) |          3 | ######
                    (0.791 , 0.89  ) |          1 | ##
                    (0.89  , 0.989 ) |          2 | ####
[I]             Relative Difference | Stats: mean=0.39215, std-dev=0.46259, var=0.21399, median=0.0028745, min=0 at (0, 0, 0), max=1.1358 at (1, 1, 3), avg-magnitude=0.39215
[I]                 ---- Histogram ----
                    Bin Range      |  Num Elems | Visualization
                    (0    , 0.114) |         17 | ########################################
                    (0.114, 0.227) |          0 |
                    (0.227, 0.341) |          2 | ####
                    (0.341, 0.454) |          1 | ##
                    (0.454, 0.568) |          0 |
                    (0.568, 0.681) |          0 |
                    (0.681, 0.795) |          1 | ##
                    (0.795, 0.909) |          1 | ##
                    (0.909, 1.02 ) |          9 | #####################
                    (1.02 , 1.14 ) |          1 | ##
[E]         FAILED | Difference exceeds tolerance (rel=1e-05, abs=1e-05)
[I]     Comparing Output: 'inds' (dtype=int32, shape=(2, 4)) with 'inds' (dtype=int64, shape=(2, 4))
[I]     Tolerance: [abs=1e-05, rel=1e-05] | Checking elemwise error
[I]         trt-runner-N0-09/09/22-00:16:40: inds | Stats: mean=10.375, std-dev=4.7942, var=22.984, median=9, min=5 at (0, 0), max=19 at (1, 2), avg-magnitude=10.375
[I]         onnxrt-runner-N0-09/09/22-00:16:40: inds | Stats: mean=10.375, std-dev=4.7942, var=22.984, median=9, min=5 at (0, 0), max=19 at (1, 2), avg-magnitude=10.375
[I]         Error Metrics: inds
[I]             Minimum Required Tolerance: elemwise error | [abs=0] OR [rel=0] (requirements may be lower if both abs/rel tolerances are set)
[I]             Absolute Difference | Stats: mean=0, std-dev=0, var=0, median=0, min=0 at (0, 0), max=0 at (0, 0), avg-magnitude=0
[I]             Relative Difference | Stats: mean=0, std-dev=0, var=0, median=0, min=0 at (0, 0), max=0 at (0, 0), avg-magnitude=0
[I]         PASSED | Difference is within tolerance (rel=1e-05, abs=1e-05)
[E]     FAILED | Mismatched outputs: ['output']
[!] FAILED | Command: /home/zeroz/.local/bin/polygraphy run tmp.onnx --trt --onnxrt --trt-opt-shapes input:[2,10,4] --trt-min-shapes input:[1,5,4] --trt-max-shapes input:[4,40,4] --input-shapes input:[2,10,4]
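
To make the "Absolute Difference" and "Relative Difference" lines in the log easier to read, here is a rough pure-Python sketch of an elementwise tolerance check. It is written under assumptions and is not polygraphy's implementation: the relative difference is taken against the reference value, and an element passes if it is within the absolute OR the relative tolerance. The function names are hypothetical.

```python
from typing import List, Tuple


def elementwise_diff(out: List[float], ref: List[float]) -> Tuple[float, float]:
    """Return (max absolute difference, max relative difference)."""
    abs_diffs = [abs(o - r) for o, r in zip(out, ref)]
    # Relative to the reference magnitude; zero references are skipped
    # to avoid dividing by zero (an assumption of this sketch).
    rel_diffs = [d / abs(r) for d, r in zip(abs_diffs, ref) if r != 0]
    return max(abs_diffs), max(rel_diffs, default=0.0)


def within_tolerance(out: List[float], ref: List[float],
                     atol: float = 1e-5, rtol: float = 1e-5) -> bool:
    """Pass if every element is within atol OR within rtol of the reference."""
    return all(
        abs(o - r) <= atol or abs(o - r) <= rtol * abs(r)
        for o, r in zip(out, ref)
    )


matching = within_tolerance([1.0, 2.0], [1.0, 2.0])
mismatching = within_tolerance([0.0, 0.5], [0.5, 0.5])
max_abs, max_rel = elementwise_diff([0.0, 0.5], [0.5, 0.5])
print(matching, mismatching, max_abs, max_rel)
```

Under this reading, an output element of 0 where the reference is non-zero gives a relative difference of 1, which is consistent with the large cluster of relative differences near 1 in the histogram above.
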

I've filed internal bug 3790543 to track this. Thanks for reporting.

@zerollzeng
Collaborator

This issue has been fixed in TRT 8.5, where a preview feature will address it. Please wait for the 8.5 release, which is coming soon :-)
