Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add mish op to ngraph #1187

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions ngraph/src/ngraph/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,8 @@ set (SRC
op/min.hpp
op/minimum.cpp
op/minimum.hpp
op/mish.cpp
op/mish.hpp
op/multiply.cpp
op/multiply.hpp
op/negative.cpp
Expand Down
81 changes: 81 additions & 0 deletions ngraph/src/ngraph/op/mish.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include "ngraph/op/mish.hpp"
#include "ngraph/attribute_visitor.hpp"

#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/runtime/reference/mish.hpp"

using namespace std;
using namespace ngraph;

constexpr NodeTypeInfo op::v4::Mish::type_info;

op::v4::Mish::Mish(const Output<Node>& arg)
: Op({arg})
{
constructor_validate_and_infer_types();
}

bool op::v4::Mish::visit_attributes(AttributeVisitor& visitor)
{
return true;
}

void op::v4::Mish::validate_and_infer_types()
{
set_output_size(1);
set_output_type(0, get_input_element_type(0), get_input_partial_shape(0));
}

shared_ptr<Node> op::v4::Mish::clone_with_new_inputs(const OutputVector& new_args) const
{
check_new_args_count(this, new_args);
return make_shared<Mish>(new_args.at(0));
}

namespace
{
template <element::Type_t ET>
inline bool evaluate(const HostTensorPtr& arg0, const HostTensorPtr& out, const size_t count)
{
using T = typename element_type_traits<ET>::value_type;
runtime::reference::mish<T>(arg0->get_data_ptr<ET>(), out->get_data_ptr<ET>(), count);
return true;
}

bool evaluate_mish(const HostTensorPtr& arg0, const HostTensorPtr& out, const size_t count)
{
bool rc = true;
out->set_unary(arg0);

switch (arg0->get_element_type())
{
TYPE_CASE(f16)(arg0, out, count);
break;
TYPE_CASE(f32)(arg0, out, count);
break;
default: rc = false; break;
}
return rc;
}
}

bool op::v4::Mish::evaluate(const HostTensorVector& outputs, const HostTensorVector& inputs)
{
return evaluate_mish(inputs[0], outputs[0], shape_size(get_output_shape(0)));
}
52 changes: 52 additions & 0 deletions ngraph/src/ngraph/op/mish.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include "ngraph/node.hpp"
#include "ngraph/op/op.hpp"

namespace ngraph
{
namespace op
{
namespace v4
{
/// \brief A Self Regularized Non-Monotonic Neural Activation Function
/// f(x) = x * tanh(log(exp(x) + 1.))
///
class NGRAPH_API Mish : public ngraph::op::Op
{
public:
static constexpr NodeTypeInfo type_info{"Mish", 4};
const NodeTypeInfo& get_type_info() const override { return type_info; }
Mish() = default;
/// \brief Constructs an Mish operation.
///
/// \param data Input tensor
Mish(const Output<Node>& arg);
bool visit_attributes(AttributeVisitor& visitor) override;
void validate_and_infer_types() override;

virtual std::shared_ptr<Node>
clone_with_new_inputs(const OutputVector& new_args) const override;

bool evaluate(const HostTensorVector& outputs,
const HostTensorVector& inputs) override;
};
}
}
}
1 change: 1 addition & 0 deletions ngraph/src/ngraph/ops.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@
#include "ngraph/op/maximum.hpp"
#include "ngraph/op/min.hpp"
#include "ngraph/op/minimum.hpp"
#include "ngraph/op/mish.hpp"
#include "ngraph/op/multiply.hpp"
#include "ngraph/op/negative.hpp"
#include "ngraph/op/non_max_suppression.hpp"
Expand Down
3 changes: 2 additions & 1 deletion ngraph/src/ngraph/opsets/opset4_tbl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

NGRAPH_OP(Abs, ngraph::op::v0)
NGRAPH_OP(Acos, ngraph::op::v0)
NGRAPH_OP(Acosh, ngraph::op::v3)
NGRAPH_OP(Add, ngraph::op::v1)
NGRAPH_OP(Asin, ngraph::op::v0)
NGRAPH_OP(Atan, ngraph::op::v0)
Expand Down Expand Up @@ -153,4 +154,4 @@ NGRAPH_OP(TopK, ngraph::op::v3)

// New operations added in opset4
NGRAPH_OP(NonMaxSuppression, ngraph::op::v4)
NGRAPH_OP(Acosh, ngraph::op::v3)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did you remove this line?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line was moved to new place (in alphabet order).

NGRAPH_OP(Mish, ngraph::op::v4)
38 changes: 38 additions & 0 deletions ngraph/src/ngraph/runtime/reference/mish.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#pragma once

#include <cmath>
#include <cstddef>

namespace ngraph
{
namespace runtime
{
namespace reference
{
template <typename T>
void mish(const T* arg, T* out, size_t count)
{
for (size_t i = 0; i < count; i++)
{
out[i] = arg[i] * std::tanh(std::log((std::exp(arg[i]) + 1.0)));
}
}
}
}
}
2 changes: 2 additions & 0 deletions ngraph/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ set(SRC
nop_elimination.cpp
op.cpp
op_eval/matmul.cpp
op_eval/mish.cpp
op_eval/non_zero.cpp
op_eval/strided_slice.cpp
op_is.cpp
Expand Down Expand Up @@ -148,6 +149,7 @@ set(SRC
type_prop/lstm_sequence.cpp
type_prop/matmul.cpp
type_prop/max_pool.cpp
type_prop/mish.cpp
type_prop/mvn.cpp
type_prop/non_max_suppression.cpp
type_prop/non_zero.cpp
Expand Down
51 changes: 51 additions & 0 deletions ngraph/test/op_eval/mish.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include <string>
#include <vector>

#include "gtest/gtest.h"

#include "ngraph/op/mish.hpp"
#include "ngraph/runtime/host_tensor.hpp"
#include "ngraph/validation_util.hpp"
#include "runtime/backend.hpp"
#include "util/test_tools.hpp"
#include "util/type_prop.hpp"

using namespace std;
using namespace ngraph;

TEST(op_eval, mish_0D)
{
auto p = make_shared<op::Parameter>(element::f32, Shape{});
auto mish = make_shared<op::v4::Mish>(p);
auto fun = make_shared<Function>(OutputVector{mish}, ParameterVector{p});

std::vector<std::vector<float>> inputs{{-1.0}, {1.0}, {20.0}};
std::vector<std::vector<float>> expected_result{{-0.303401}, {0.86509835720062256}, {20.0}};

for (size_t i = 0; i < inputs.size(); i++)
{
auto result = make_shared<HostTensor>();
ASSERT_TRUE(
fun->evaluate({result}, {make_host_tensor<element::Type_t::f32>(Shape{}, inputs[i])}));
EXPECT_EQ(result->get_element_type(), element::f32);
EXPECT_EQ(result->get_shape(), (Shape{}));
auto result_data = read_vector<float>(result);
EXPECT_NEAR(result_data[0], expected_result[i][0], 0.3);
iimironov marked this conversation as resolved.
Show resolved Hide resolved
}
}
54 changes: 54 additions & 0 deletions ngraph/test/type_prop/mish.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
//*****************************************************************************
// Copyright 2017-2020 Intel Corporation
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//*****************************************************************************

#include "gtest/gtest.h"
#include "ngraph/ngraph.hpp"
#include "util/type_prop.hpp"

using namespace std;
using namespace ngraph;

TEST(type_prop, mish)
{
auto data = make_shared<op::Parameter>(element::f32, Shape{1, 3, 6});
auto mish_func = make_shared<op::v4::Mish>(data);
EXPECT_EQ(mish_func->get_element_type(), element::f32);
EXPECT_EQ(mish_func->get_shape(), (Shape{1, 3, 6}));
}

TEST(type_prop, mish_partial)
iimironov marked this conversation as resolved.
Show resolved Hide resolved
{
auto data = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 6});
auto mish_func = make_shared<op::v4::Mish>(data);
EXPECT_EQ(mish_func->get_element_type(), element::f32);
ASSERT_TRUE(mish_func->get_output_partial_shape(0).same_scheme(
(PartialShape{1, Dimension::dynamic(), 6})));

// rank unknown
auto mish_partial = make_shared<op::v4::Mish>(
make_shared<op::Parameter>(element::f32, PartialShape::dynamic()));
ASSERT_TRUE(mish_partial->get_output_partial_shape(0).same_scheme(PartialShape::dynamic()));
}

TEST(type_prop, mish_partial_static_rank)
{
auto data = make_shared<op::Parameter>(element::f32, PartialShape{1, Dimension::dynamic(), 6});
auto mish_func = make_shared<op::v4::Mish>(data);
EXPECT_EQ(mish_func->get_element_type(), element::f32);
ASSERT_TRUE(mish_func->get_output_partial_shape(0).same_scheme(
(PartialShape{1, Dimension::dynamic(), 6})));
ASSERT_TRUE(mish_func->get_output_partial_shape(0).rank().is_static());
}