
High Numerical Errors in Mish Activation with FLOAT16 Precision on Neural Engine #2359

Open
ChinChangYang opened this issue Oct 5, 2024 · 3 comments
Labels
bug Unexpected behaviour that should be corrected (type)

@ChinChangYang (Contributor)

🐞Describing the bug

The built-in Mish activation function in coremltools introduces significant numerical errors in Core ML models converted with 16-bit floating-point precision (FLOAT16) and run with ComputeUnit=CPU_AND_NE. Specifically, converting a model that uses the Mish activation produces substantial discrepancies in output predictions compared to the original model, leading to high error rates across various metrics.
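For reference, Mish is defined as Mish(x) = x * tanh(softplus(x)) = x * tanh(ln(1 + e^x)), so any inaccuracy in the softplus term propagates through tanh into the output.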

Stack Trace

N/A

To Reproduce

Follow the steps below to reproduce the high numerical errors using the built-in Mish activation function:

  1. Clone the KataGo repository:

    git clone --branch v1.15.3-coreml1 https://github.com/ChinChangYang/KataGo.git KataGo-v1.15.3-coreml1
    cd KataGo-v1.15.3-coreml1/python
  2. Download a KataGo model in RAW checkpoint format:

    wget https://media.katagotraining.org/uploaded/networks/zips/kata1/kata1-b18c384nbt-s9996604416-d4316597426.zip
    unzip kata1-b18c384nbt-s9996604416-d4316597426.zip
    ln -s kata1-b18c384nbt-s9996604416-d4316597426/model.ckpt model.ckpt
  3. Install Python modules:

    pip install torch coremltools matplotlib
  4. Evaluate the high error using the built-in Mish implementation:

    wget https://gist.githubusercontent.com/ChinChangYang/529ccdffb90b60d307550b067f2fbab8/raw/abc3050cfad77e1ec87c92f61bd4b8c1b4f6cc28/testcoremlerror_original.py
    python testcoremlerror_original.py

    Expected Output:

    Mean Absolute Errors Across Samples:
      var_2572:
        FLOAT16: 1.042287
        FLOAT32: 0.000095
      linear_9:
        FLOAT16: 3.587491
        FLOAT32: 0.000245
      linear_10:
        FLOAT16: 2.812497
        FLOAT32: 0.000182
      linear_11:
        FLOAT16: 2.498940
        FLOAT32: 0.000269
      var_2631:
        FLOAT16: 0.079012
        FLOAT32: 0.000011
    
  5. Evaluate the lower error using the alternative Mish implementation:

    wget https://gist.githubusercontent.com/ChinChangYang/b9d45f13a40ff738baa607a265a0b2c3/raw/8bf3ae8e66946451be7dbd0d6debdae9d8e82fcf/testcoremlerror_workaround.py
    python testcoremlerror_workaround.py

    Expected Output:

    Mean Absolute Errors Across Samples:
      var_2572:
        FLOAT16: 0.008898
        FLOAT32: 0.000395
      linear_9:
        FLOAT16: 0.018509
        FLOAT32: 0.000812
      linear_10:
        FLOAT16: 0.014011
        FLOAT32: 0.000628
      linear_11:
        FLOAT16: 0.016918
        FLOAT32: 0.000859
      var_2631:
        FLOAT16: 0.001414
        FLOAT32: 0.000036
    

System environment (please complete the following information):

  • coremltools version: 8.0
  • OS: macOS 15.0
  • Any other relevant version information:
    • PyTorch: 2.4.1

Additional context

The issue arises specifically when using ComputeUnit=CPU_AND_NE with Precision=FLOAT16. The built-in Mish activation function in coremltools leads to high numerical errors, as evidenced by metrics such as winrateError, leadError, and others showing discrepancies upwards of 25%. Switching to an alternative Mish implementation drastically reduces these errors to below 1%, albeit with a 32% increase in inference time due to the additional operators introduced.

This problem is isolated to 16-bit floating point precision on the Neural Engine (NE), as experiments with other compute units and precision settings (e.g., FLOAT32) do not exhibit the same high error rates. The significant reduction in error using the alternative Mish implementation suggests that the built-in Mish operator may have implementation issues when used in this specific configuration.
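One plausible contributor (my assumption; I have not confirmed how the Neural Engine evaluates softplus) is float16 overflow inside softplus: the largest finite float16 value is 65504, so exp(x) overflows to inf for x > ln(65504) ≈ 11.09. A quick NumPy illustration of that failure mode:

import numpy as np

# float16 saturates early: the largest finite value is 65504,
# so exp(x) overflows to inf once x exceeds ln(65504) ≈ 11.09
x = np.float16(12.0)
print(np.exp(x))            # inf (with an overflow warning)
print(np.log1p(np.exp(x)))  # inf: a naive float16 softplus breaks down here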

This issue is based on a detailed analysis of numerical errors in Core ML models that use the Mish activation function with 16-bit precision, as documented in the related blog post. Further investigation and collaboration from the coremltools engineering team would be greatly appreciated.

@ChinChangYang (Contributor, Author)

I wrote the alternative Mish implementation here:

from coremltools.converters.mil import Builder as mb
from coremltools.converters.mil.frontend.torch.ops import _get_inputs


def mish_torch_sigmoid(context, node):
    inputs = _get_inputs(context, node, expected=1)
    x = inputs[0]

    threshold = 10.39
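    # 10.39 presumably keeps exp() finite in float16 (e^10.39 ≈ 3.2e4 < 65504),
    # while softplus(x) ≈ x to within float16 resolution above this point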

    # Approximating conditional behavior using sigmoid function
    sigmoid_threshold = mb.sigmoid(x=mb.sub(x=x, y=threshold))
    
    # Approximate implementation of Softplus
    softplus_part = mb.softplus(x=mb.minimum(x=x, y=threshold))
    softplus = mb.add(x=mb.mul(x=x, y=sigmoid_threshold), 
                      y=mb.mul(x=softplus_part, y=mb.sub(x=1.0, y=sigmoid_threshold)))

    # Mish(x) = x * tanh(Softplus(x))
    tanh_softplus = mb.tanh(x=softplus)
    res = mb.mul(x=x, y=tanh_softplus, name=node.name)
    context.add(res)
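For what it's worth, here is a quick sanity check of the blended formula in plain NumPy (independent of coremltools; mish_reference and mish_blended are names I made up for this sketch). In float64 the approximation error stays on the order of 1e-5 over [-20, 20]:

import numpy as np

def mish_reference(x):
    # exact Mish in float64: x * tanh(softplus(x))
    return x * np.tanh(np.log1p(np.exp(x)))

def mish_blended(x, threshold=10.39):
    # the workaround's formula: a sigmoid gate blends softplus(x) below
    # the threshold with the asymptote softplus(x) ~ x above it
    gate = 1.0 / (1.0 + np.exp(-(x - threshold)))
    softplus_capped = np.log1p(np.exp(np.minimum(x, threshold)))
    softplus = x * gate + softplus_capped * (1.0 - gate)
    return x * np.tanh(softplus)

x = np.linspace(-20.0, 20.0, 100001)
print(np.max(np.abs(mish_reference(x) - mish_blended(x))))  # on the order of 1e-5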

@TobyRoseman (Collaborator)

For security reasons, I am not able to download and run your network. Please create a minimal example that demonstrates the issue: ideally, a small amount of self-contained code that I can just copy and paste.

@ChinChangYang (Contributor, Author)

> For security reasons, I am not able to download and run your network. Please create a minimal example that demonstrates the issue: ideally, a small amount of self-contained code that I can just copy and paste.

I have created a minimal example below to demonstrate the issue. It is a small amount of self-contained code, so you can just copy and paste it.

To Reproduce

Below are two scripts that reproduce this issue: one uses the built-in Mish activation, and the other uses the alternative Mish implementation.

Built-in Mish Activation

import torch
import torch.nn as nn
import coremltools as ct
import numpy as np

from coremltools.converters.mil.frontend.torch.torch_op_registry import (
    _TORCH_OPS_REGISTRY,
    register_torch_op,
)
from coremltools.converters.mil.frontend.torch.ops import _get_inputs
from coremltools.converters.mil import Builder as mb

# TO ENABLE THE WORKAROUND MISH FUNCTION, UNCOMMENT THE FOLLOWING LINES OF CODE

# # Remove the original mish function
# if "mish" in _TORCH_OPS_REGISTRY:
#     del _TORCH_OPS_REGISTRY["mish"]


# # Register the new mish function
# @register_torch_op
# def mish(context, node):
#     inputs = _get_inputs(context, node, expected=1)
#     x = inputs[0]

#     threshold = 10.39

#     # Approximating conditional behavior using sigmoid function
#     sigmoid_threshold = mb.sigmoid(x=mb.sub(x=x, y=threshold))

#     # Approximate implementation of Softplus
#     softplus_part = mb.softplus(x=mb.minimum(x=x, y=threshold))
#     softplus = mb.add(
#         x=mb.mul(x=x, y=sigmoid_threshold),
#         y=mb.mul(x=softplus_part, y=mb.sub(x=1.0, y=sigmoid_threshold)),
#     )

#     # Mish(x) = x * tanh(Softplus(x))
#     tanh_softplus = mb.tanh(x=softplus)
#     res = mb.mul(x=x, y=tanh_softplus, name=node.name)
#     context.add(res)


# TO ENABLE THE WORKAROUND MISH FUNCTION, UNCOMMENT THE ABOVE LINES OF CODE


class MishModel(nn.Module):
    def __init__(self):
        super(MishModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding="same")
        self.act = nn.Mish()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28 * 16, 10)

    def forward(self, x):
        x = self.act(self.conv1(x))
        x = self.flatten(x)
        x = self.fc1(x)
        return x


# Export to Core ML
def export_to_coreml(model, target_path, compute_units):
    dummy_input = torch.randn(1, 1, 28, 28)

    with torch.no_grad():
        model.eval()
        traced_model = torch.jit.trace(model, dummy_input)

    inputs = [ct.TensorType(shape=tuple(dummy_input.shape))]

    mlmodel = ct.convert(
        traced_model,
        inputs=inputs,
        compute_precision=ct.precision.FLOAT16,
        minimum_deployment_target=ct.target.iOS15,
        compute_units=compute_units,
    )

    mlmodel.save(target_path)
    return mlmodel


def generate_random_inputs(batch_size=1):
    inputs = torch.randint(
        low=0,
        high=40,
        size=(
            batch_size,
            1,
            28,
            28,
        ),
        dtype=torch.float32,
    )

    return inputs


def get_coreml_outputs(mlmodel, inputs):
    try:
        predictions = mlmodel.predict(inputs)
        return predictions
    except Exception as e:
        print(f"Error during CoreML model prediction: {e}")
        raise


def flatten_outputs(outputs):
    flattened = []
    if isinstance(outputs, torch.Tensor):
        flattened.append(outputs)
    elif isinstance(outputs, (tuple, list)):
        for item in outputs:
            flattened.extend(flatten_outputs(item))
    else:
        raise TypeError(f"Unsupported output type: {type(outputs)}")
    return flattened


def compute_error(torch_outputs, coreml_outputs, output_names):
    errors = {}
    flattened_torch_outputs = flatten_outputs(torch_outputs)

    for idx, torch_output in enumerate(flattened_torch_outputs):
        torch_np = torch_output.cpu().numpy()
        coreml_key = output_names[idx]
        coreml_np = coreml_outputs[coreml_key]
        error = np.mean(np.abs(torch_np - coreml_np))
        errors[coreml_key] = error

    return errors


# Main function
def main():
    model = MishModel()

    # Both models are converted with FLOAT16 precision; only the compute units differ.
    coreml_model_gpu = export_to_coreml(
        model,
        "model_gpu.mlpackage",
        compute_units=ct.ComputeUnit.CPU_AND_GPU,
    )

    coreml_model_ne = export_to_coreml(
        model,
        "model_ne.mlpackage",
        compute_units=ct.ComputeUnit.CPU_AND_NE,
    )

    spec = coreml_model_ne._spec
    input_names = [input.name for input in spec.description.input]
    output_names = [output.name for output in spec.description.output]
    input_name = input_names[0]
    num_samples = 30
    test_inputs = generate_random_inputs(batch_size=num_samples)
    errors_ne = {}
    errors_gpu = {}
    
    for name in output_names:
        errors_ne[name] = []
        errors_gpu[name] = []

    # Iterate over each sample
    for i in range(num_samples):
        # Prepare single sample inputs
        single_input = test_inputs[i].unsqueeze(0)  # Shape: (1, C, H, W)

        # Prepare input dictionary for CoreML prediction
        input_dict = {
            input_name: single_input.numpy(),
        }

        # Compute PyTorch outputs
        with torch.no_grad():
            torch_output = model(single_input)

        # Ensure torch_output is a tuple
        if not isinstance(torch_output, tuple):
            torch_output = (torch_output,)

        # Compute CoreML outputs
        coreml_output_ne = get_coreml_outputs(coreml_model_ne, input_dict)
        coreml_output_gpu = get_coreml_outputs(coreml_model_gpu, input_dict)

        # Compute errors for each output
        error_current_ne = compute_error(torch_output, coreml_output_ne, output_names)
        error_current_gpu = compute_error(torch_output, coreml_output_gpu, output_names)

        # Accumulate errors
        for name in output_names:
            if error_current_ne.get(name) is not None:
                errors_ne[name].append(error_current_ne[name])
            if error_current_gpu.get(name) is not None:
                errors_gpu[name].append(error_current_gpu[name])

    # Compute mean errors across all samples
    mean_errors_ne = {name: np.mean(errors_ne[name]) for name in output_names}
    mean_errors_gpu = {name: np.mean(errors_gpu[name]) for name in output_names}

    # Display mean errors
    print("\nMean Absolute Errors Across Samples:")
    for output_name in output_names:
        ne_error = mean_errors_ne[output_name]
        gpu_error = mean_errors_gpu[output_name]
        print(f"  {output_name}:")
        print(f"    NE:  {ne_error:.6f}")
        print(f"    GPU: {gpu_error:.6f}")


if __name__ == "__main__":
    main()

Expected Output

Converting PyTorch Frontend ==> MIL Ops:  90%|█████████████████████████████████████████████████████████████████████████████████████▌         | 9/10 [00:00<00:00, 2181.88 ops/s]
Running MIL frontend_pytorch pipeline: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 5500.01 passes/s]
Running MIL default pipeline: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 88/88 [00:00<00:00, 3854.05 passes/s]
Running MIL backend_mlprogram pipeline: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:00<00:00, 9894.17 passes/s]
Converting PyTorch Frontend ==> MIL Ops:  90%|█████████████████████████████████████████████████████████████████████████████████████▌         | 9/10 [00:00<00:00, 5401.94 ops/s]
Running MIL frontend_pytorch pipeline: 100%|███████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 7162.40 passes/s]
Running MIL default pipeline: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 88/88 [00:00<00:00, 4313.92 passes/s]
Running MIL backend_mlprogram pipeline: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:00<00:00, 9607.11 passes/s]

Mean Absolute Errors Across Samples:
  linear_0:
    NE:  3.951386
    GPU: 0.001291

Alternative Mish Implementation

import torch
import torch.nn as nn
import coremltools as ct
import numpy as np

from coremltools.converters.mil.frontend.torch.torch_op_registry import (
    _TORCH_OPS_REGISTRY,
    register_torch_op,
)
from coremltools.converters.mil.frontend.torch.ops import _get_inputs
from coremltools.converters.mil import Builder as mb

# THE WORKAROUND MISH FUNCTION IS ENABLED BELOW

# Remove the original mish function
if "mish" in _TORCH_OPS_REGISTRY:
    del _TORCH_OPS_REGISTRY["mish"]


# Register the new mish function
@register_torch_op
def mish(context, node):
    inputs = _get_inputs(context, node, expected=1)
    x = inputs[0]

    threshold = 10.39
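    # 10.39 presumably keeps exp() finite in float16 (e^10.39 ≈ 3.2e4 < 65504),
    # while softplus(x) ≈ x to within float16 resolution above this point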

    # Approximating conditional behavior using sigmoid function
    sigmoid_threshold = mb.sigmoid(x=mb.sub(x=x, y=threshold))

    # Approximate implementation of Softplus
    softplus_part = mb.softplus(x=mb.minimum(x=x, y=threshold))
    softplus = mb.add(
        x=mb.mul(x=x, y=sigmoid_threshold),
        y=mb.mul(x=softplus_part, y=mb.sub(x=1.0, y=sigmoid_threshold)),
    )

    # Mish(x) = x * tanh(Softplus(x))
    tanh_softplus = mb.tanh(x=softplus)
    res = mb.mul(x=x, y=tanh_softplus, name=node.name)
    context.add(res)


# END OF THE WORKAROUND MISH FUNCTION


class MishModel(nn.Module):
    def __init__(self):
        super(MishModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, kernel_size=3, padding="same")
        self.act = nn.Mish()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28 * 28 * 16, 10)

    def forward(self, x):
        x = self.act(self.conv1(x))
        x = self.flatten(x)
        x = self.fc1(x)
        return x


# Export to Core ML
def export_to_coreml(model, target_path, compute_units):
    dummy_input = torch.randn(1, 1, 28, 28)

    with torch.no_grad():
        model.eval()
        traced_model = torch.jit.trace(model, dummy_input)

    inputs = [ct.TensorType(shape=tuple(dummy_input.shape))]

    mlmodel = ct.convert(
        traced_model,
        inputs=inputs,
        compute_precision=ct.precision.FLOAT16,
        minimum_deployment_target=ct.target.iOS15,
        compute_units=compute_units,
    )

    mlmodel.save(target_path)
    return mlmodel


def generate_random_inputs(batch_size=1):
    inputs = torch.randint(
        low=0,
        high=40,
        size=(
            batch_size,
            1,
            28,
            28,
        ),
        dtype=torch.float32,
    )

    return inputs


def get_coreml_outputs(mlmodel, inputs):
    try:
        predictions = mlmodel.predict(inputs)
        return predictions
    except Exception as e:
        print(f"Error during CoreML model prediction: {e}")
        raise


def flatten_outputs(outputs):
    flattened = []
    if isinstance(outputs, torch.Tensor):
        flattened.append(outputs)
    elif isinstance(outputs, (tuple, list)):
        for item in outputs:
            flattened.extend(flatten_outputs(item))
    else:
        raise TypeError(f"Unsupported output type: {type(outputs)}")
    return flattened


def compute_error(torch_outputs, coreml_outputs, output_names):
    errors = {}
    flattened_torch_outputs = flatten_outputs(torch_outputs)

    for idx, torch_output in enumerate(flattened_torch_outputs):
        torch_np = torch_output.cpu().numpy()
        coreml_key = output_names[idx]
        coreml_np = coreml_outputs[coreml_key]
        error = np.mean(np.abs(torch_np - coreml_np))
        errors[coreml_key] = error

    return errors


# Main function
def main():
    model = MishModel()

    # Both models are converted with FLOAT16 precision; only the compute units differ.
    coreml_model_gpu = export_to_coreml(
        model,
        "model_gpu.mlpackage",
        compute_units=ct.ComputeUnit.CPU_AND_GPU,
    )

    coreml_model_ne = export_to_coreml(
        model,
        "model_ne.mlpackage",
        compute_units=ct.ComputeUnit.CPU_AND_NE,
    )

    spec = coreml_model_ne._spec
    input_names = [input.name for input in spec.description.input]
    output_names = [output.name for output in spec.description.output]
    input_name = input_names[0]
    num_samples = 30
    test_inputs = generate_random_inputs(batch_size=num_samples)
    errors_ne = {}
    errors_gpu = {}
    
    for name in output_names:
        errors_ne[name] = []
        errors_gpu[name] = []

    # Iterate over each sample
    for i in range(num_samples):
        # Prepare single sample inputs
        single_input = test_inputs[i].unsqueeze(0)  # Shape: (1, C, H, W)

        # Prepare input dictionary for CoreML prediction
        input_dict = {
            input_name: single_input.numpy(),
        }

        # Compute PyTorch outputs
        with torch.no_grad():
            torch_output = model(single_input)

        # Ensure torch_output is a tuple
        if not isinstance(torch_output, tuple):
            torch_output = (torch_output,)

        # Compute CoreML outputs
        coreml_output_ne = get_coreml_outputs(coreml_model_ne, input_dict)
        coreml_output_gpu = get_coreml_outputs(coreml_model_gpu, input_dict)

        # Compute errors for each output
        error_current_ne = compute_error(torch_output, coreml_output_ne, output_names)
        error_current_gpu = compute_error(torch_output, coreml_output_gpu, output_names)

        # Accumulate errors
        for name in output_names:
            if error_current_ne.get(name) is not None:
                errors_ne[name].append(error_current_ne[name])
            if error_current_gpu.get(name) is not None:
                errors_gpu[name].append(error_current_gpu[name])

    # Compute mean errors across all samples
    mean_errors_ne = {name: np.mean(errors_ne[name]) for name in output_names}
    mean_errors_gpu = {name: np.mean(errors_gpu[name]) for name in output_names}

    # Display mean errors
    print("\nMean Absolute Errors Across Samples:")
    for output_name in output_names:
        ne_error = mean_errors_ne[output_name]
        gpu_error = mean_errors_gpu[output_name]
        print(f"  {output_name}:")
        print(f"    NE:  {ne_error:.6f}")
        print(f"    GPU: {gpu_error:.6f}")


if __name__ == "__main__":
    main()

Expected Output

Converting PyTorch Frontend ==> MIL Ops:  90%|█████████████████████████████████████▊    | 9/10 [00:00<00:00, 1975.03 ops/s]
Running MIL frontend_pytorch pipeline: 100%|██████████████████████████████████████████| 5/5 [00:00<00:00, 4351.84 passes/s]
Running MIL default pipeline: 100%|█████████████████████████████████████████████████| 88/88 [00:00<00:00, 2862.23 passes/s]
Running MIL backend_mlprogram pipeline: 100%|███████████████████████████████████████| 12/12 [00:00<00:00, 6956.69 passes/s]
Converting PyTorch Frontend ==> MIL Ops:  90%|█████████████████████████████████████▊    | 9/10 [00:00<00:00, 4345.93 ops/s]
Running MIL frontend_pytorch pipeline: 100%|██████████████████████████████████████████| 5/5 [00:00<00:00, 5482.75 passes/s]
Running MIL default pipeline: 100%|█████████████████████████████████████████████████| 88/88 [00:00<00:00, 3141.61 passes/s]
Running MIL backend_mlprogram pipeline: 100%|███████████████████████████████████████| 12/12 [00:00<00:00, 7155.48 passes/s]

Mean Absolute Errors Across Samples:
  linear_0:
    NE:  0.001768
    GPU: 0.001291
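
With the workaround registered, the NE error drops from 3.951386 to 0.001768, essentially matching the GPU result (0.001291), which is consistent with the built-in Mish lowering being the problem on the Neural Engine.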
