Skip to content

Commit

Permalink
Adding support for L2 normalisation operation (#65)
Browse files Browse the repository at this point in the history
* Adding support for L2 normalisation operation

* Fixing L2 norm operation and tests

* Fix for L2 norm and tests

* Add default param to L2 norm and test fix

---------

Co-authored-by: SarahByrneIntel <sarahbyrne@intel.com>
  • Loading branch information
SarahByrneIntel and SarahByrneIntel authored Jun 21, 2024
1 parent 417bd76 commit e4d0d61
Show file tree
Hide file tree
Showing 5 changed files with 67 additions and 1 deletion.
15 changes: 15 additions & 0 deletions include/intel_npu_acceleration_library/nn_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -702,6 +702,21 @@ class ModelFactory : public intel_npu_acceleration_library::OVInferenceModel {
return sdpa.get();
}

/**
* @brief Create a new L2 normalization operation
*
* @param data operation's input node
* @param axes node indicating axes along which reduction is calculated
* @param eps the epsilon added to L2 norm
* @return ov::op::Op*
*/
ov::op::Op* normL2(ov::op::Op* data, ov::op::Op* axes, float eps) {
auto normL2 =
std::make_shared<ov::opset1::NormalizeL2>(data->output(0), axes->output(0), eps, ov::op::EpsMode::MAX);
operations.push_back(normL2);
return normL2.get();
}

/**
* @brief Compile the model
*
Expand Down
19 changes: 19 additions & 0 deletions intel_npu_acceleration_library/backend/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,25 @@ def slice(
end_mask_ptr,
)

@return_tensor
def normL2(
self, input_node: ctypes._Pointer, axis: int, eps: Optional[float] = 1e-12
) -> ctypes._Pointer:
"""Generate an L2 normalization layer.
Args:
input_node (ctypes._Pointer): layer input node
axis (int): axis
eps (float): epsilon added to L2 norm
Returns:
ctypes._Pointer: output node
"""
if axis < 0:
axis = abs(axis)
axis_node = self.constant(axis).node # type: ignore
return backend_lib.normL2(self._mm, input_node, axis_node, eps)

def get_output_tensor_shape(self):
"""Get output tensor shape.
Expand Down
5 changes: 5 additions & 0 deletions intel_npu_acceleration_library/backend/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,11 @@ def get_supported_ops() -> List[SupportedOp]:
inputs=4,
parameters=[ctypes.c_bool],
),
SupportedOp(
name="normL2",
inputs=2,
parameters=[ctypes.c_float],
),
SupportedOp(
name="gather",
inputs=3,
Expand Down
5 changes: 5 additions & 0 deletions src/bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -465,4 +465,9 @@ intel_npu_acceleration_library_DLL_API ov::op::Op* scaled_dot_product_attention(
ov::op::Op* attn_mask, bool is_causal) {
return factory->scaled_dot_product_attention(query, key, value, attn_mask, is_causal);
}

intel_npu_acceleration_library_DLL_API ov::op::Op* normL2(intel_npu_acceleration_library::ModelFactory* factory,
ov::op::Op* data, ov::op::Op* axes, float eps) {
return factory->normL2(data, axes, eps);
}
}
24 changes: 23 additions & 1 deletion test/python/test_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import numpy as np
import pytest
import torch
import itertools


class MLP_PT(torch.nn.Module):
Expand Down Expand Up @@ -313,3 +312,26 @@ def test_constant(batch, hidden_dim):
assert np.isfinite(out).all(), "NPU output contains NaN or Inf"

assert 1 - r2_score(reference, out) < 0.001


@pytest.mark.parametrize("batch", [16, 128])
@pytest.mark.parametrize("hidden_dim", [256, 512])
@pytest.mark.parametrize("axis", [0, 1])
def test_normalisation(batch, hidden_dim, axis):

X = torch.rand((batch, hidden_dim)).to(torch.float16) - 0.5

model = NNFactory()
input = model.parameter(X.shape)
output = model.normL2(input, axis)
model.compile(output)
out = model.run(X.numpy())

reference = torch.nn.functional.normalize(X, p=2.0, dim=axis).numpy()
print(out)
print(reference)
assert out.shape == reference.shape, "Output shape mismatch"
assert np.isfinite(reference).all(), "Pytorch Reference contains NaN or Inf"
assert np.isfinite(out).all(), "NPU output contains NaN or Inf"

assert 1 - r2_score(reference, out) < 0.001

0 comments on commit e4d0d61

Please sign in to comment.