Add prelu and normalize ops (#107)
alessandropalla authored Jul 25, 2024
1 parent 0ebb47b commit f04d499
Showing 4 changed files with 73 additions and 2 deletions.
13 changes: 13 additions & 0 deletions include/intel_npu_acceleration_library/nn_factory.h
@@ -914,6 +914,19 @@ class ModelFactory : public intel_npu_acceleration_library::OVInferenceModel {
return power.get();
}

/**
* @brief Create a new prelu operation
*
* @param x1 operation's input node
* @param slope operation's slope
* @return ov::op::Op*
*/
ov::op::Op* prelu(ov::op::Op* x1, ov::op::Op* slope) {
auto prelu_op = std::make_shared<ov::op::v0::PRelu>(x1->output(0), slope->output(0));
operations.push_back(prelu_op);
return prelu_op.get();
}

/**
* @brief Create a new log softmax operation
*
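
For reference, the PRelu node created by the new factory method applies an element-wise leaky activation with a learnable slope on the negative side. A quick sanity check of that semantics in plain PyTorch (illustration only, CPU, no OpenVINO or NPU involved):

import torch
import torch.nn.functional as F

x = torch.randn(2, 8)
slope = torch.full((8,), 0.25)  # one learnable slope per channel

# PReLU(x) = max(0, x) + slope * min(0, x)
ref = torch.clamp(x, min=0) + slope * torch.clamp(x, max=0)
assert torch.allclose(F.prelu(x, slope), ref)
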
1 change: 1 addition & 0 deletions intel_npu_acceleration_library/backend/ops.py
@@ -57,6 +57,7 @@ def get_supported_ops() -> List[SupportedOp]:
SupportedOp(name="log_act", inputs=1),
SupportedOp(name="negative", inputs=1),
SupportedOp(name="relu", inputs=1),
SupportedOp(name="prelu", inputs=2),
SupportedOp(name="sigmoid", inputs=1),
SupportedOp(name="sign", inputs=1),
SupportedOp(name="sin_act", inputs=1),
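
The SupportedOp(name="prelu", inputs=2) entry above registers prelu as a two-input op: the data tensor and the slope tensor, mirroring torch.nn.functional.prelu(input, weight). For illustration (plain PyTorch, not part of the diff), with a 4D input the slope is broadcast per channel:

import torch
import torch.nn.functional as F

x = torch.randn(1, 4, 8, 8)                     # input tensor (first op input)
slope = torch.tensor([0.10, 0.20, 0.10, 0.30])  # per-channel slope (second op input)
y = F.prelu(x, slope)                           # y.shape == (1, 4, 8, 8)
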
56 changes: 54 additions & 2 deletions intel_npu_acceleration_library/nn/functional.py
@@ -443,15 +443,45 @@ def layer_norm(
"""
axis = input.shape.index(normalized_shape[0])
ln = generate_op([input], "normL2", axis, eps)
-if weight:
+if weight is not None:
ln = ln * weight

-if bias:
+if bias is not None:
ln = ln + bias

return ln
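
The switch from "if weight:" to "if weight is not None:" (the two deletions in this commit) avoids evaluating the truth value of a tensor, which is ambiguous for anything with more than one element, and only checks whether the optional argument was supplied. Assuming the library's Tensor wrapper mirrors PyTorch's behavior here, the difference looks like this in plain PyTorch:

import torch

w = torch.ones(3)
try:
    bool(w)  # truth value of a multi-element tensor
except RuntimeError as err:
    print(err)  # Boolean value of Tensor with more than one element is ambiguous

if w is not None:  # the new check: only tests whether the argument was provided
    print("weight provided")
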


@implements(torch.nn.functional.normalize)
def normalize(
input: Tensor,
p: float = 2.0,
dim: int = 1,
eps: float = 1e-12,
out: Optional[Tensor] = None,
) -> Tensor:
"""Return the normalized tensor.
Args:
input (Tensor): The input tensor.
p (float): The power value. Defaults to 2.0.
dim (int): The dim to normalize. Defaults to 1.
eps (float): The epsilon value. Defaults to 1e-12.
out (Optional[Tensor], optional): Output tensor. Defaults to None.
Raises:
NotImplementedError: p != 2 is not supported yet
Returns:
Tensor: Output tensor.
"""
if p != 2:
raise NotImplementedError("p != 2 is not supported yet")

out = generate_op([input], "normL2", dim, eps)
return out
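
The p != 2 guard restricts this to L2 normalization, which is what the backend normL2 op provides. For reference, torch.nn.functional.normalize with p=2 divides by the L2 norm along dim, clamped below by eps; a CPU-only check of that definition (illustration, not from the diff):

import torch
import torch.nn.functional as F

x = torch.randn(4, 16)
ref = x / x.norm(p=2, dim=1, keepdim=True).clamp_min(1e-12)
assert torch.allclose(F.normalize(x, p=2.0, dim=1), ref)
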


@implements(torch.ceil)
def ceil(x: Tensor, out: Optional[Tensor] = None) -> Tensor:
"""Return the ceil of a tensor element-wise.
@@ -814,6 +844,20 @@ def relu(x: Tensor, inplace=False) -> Tensor:
return out


@implements(torch.nn.functional.prelu)
def prelu(x: Tensor, weight: Tensor) -> Tensor:
"""Return the parametric relu of a tensor element-wise.
Args:
x (Tensor): The input tensor.
weight (Tensor): The weights tensor.
Returns:
Tensor: Output tensor.
"""
return generate_op([x, weight], "prelu")
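
With the functional registered, a module using nn.PReLU can be offloaded like any other supported model. A minimal sketch, assuming the library's top-level compile() entry point and available NPU drivers (dtype and device options omitted):

import torch
import intel_npu_acceleration_library

class Tiny(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = torch.nn.Linear(16, 16)
        self.act = torch.nn.PReLU(16)  # dispatches to torch.nn.functional.prelu

    def forward(self, x):
        return self.act(self.fc(x))

model = intel_npu_acceleration_library.compile(Tiny())
y = model(torch.randn(1, 16))
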


@implements(torch.nn.functional.sigmoid)
def sigmoid(x: Tensor) -> Tensor:
"""Return the sigmoid of a tensor element-wise.
@@ -955,6 +999,10 @@ def adaptive_avg_pool2d(input: Tensor, output_size: Sequence[int]):
Returns:
Tensor: Output tensor.
"""
if output_size == 1:
return generate_op(
[input], "reduce_mean", reduction_axes=[-2, -1], keep_dims=True
)
return generate_op([input, output_size], "adaptive_avg_pool")


@@ -977,6 +1025,10 @@ def adaptive_max_pool2d(
"""
if return_indices:
raise NotImplementedError("return_indices is not supported yet")
if output_size == 1:
return generate_op(
[input], "reduce_max", reduction_axes=[-2, -1], keep_dims=True
)
return generate_op([input, output_size], "adaptive_max_pool")


5 changes: 5 additions & 0 deletions src/bindings.cpp
@@ -262,6 +262,11 @@ intel_npu_acceleration_library_DLL_API ov::op::Op* relu(intel_npu_acceleration_l
return factory->relu(in0);
}

intel_npu_acceleration_library_DLL_API ov::op::Op* prelu(intel_npu_acceleration_library::ModelFactory* factory,
ov::op::Op* in0, ov::op::Op* in1) {
return factory->prelu(in0, in1);
}

intel_npu_acceleration_library_DLL_API ov::op::Op* sigmoid(intel_npu_acceleration_library::ModelFactory* factory,
ov::op::Op* in0) {
return factory->sigmoid(in0);
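
The exported prelu symbol takes the factory handle plus the two operand nodes and is intended to be called from the Python backend. The sketch below is an illustrative ctypes binding only; the actual loader and library name used by intel_npu_acceleration_library may differ:

import ctypes

# Hypothetical library name; the package resolves the real path internally.
lib = ctypes.CDLL("libintel_npu_acceleration_library.so")

lib.prelu.restype = ctypes.c_void_p                 # returns ov::op::Op*
lib.prelu.argtypes = [ctypes.c_void_p,              # ModelFactory*
                      ctypes.c_void_p,              # in0: ov::op::Op*
                      ctypes.c_void_p]              # in1 (slope): ov::op::Op*
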
