Add math ops mul, sum, softmax (PaddlePaddle#52)
* add mul, sum, softmax

* fix sum, add tests

* add CopyOpAttr()
gglin001 authored Aug 11, 2021
1 parent 453ca93 commit fe0212c
Showing 8 changed files with 316 additions and 7 deletions.
27 changes: 22 additions & 5 deletions paddle/fluid/framework/ipu/ipu_backend.cc
@@ -327,6 +327,11 @@ void IpuBackend::LowerBody(const ir::Graph* graph) {
popart::TensorId result = builder_->aiOnnxOpset11().conv(
inputs, dilations, group, {}, pads, strides);
tensors_.emplace(outputs[0], result);
} else if (op_type == "MatMul") {
auto inputs = GetOpInputs(op);
auto outputs = op->Output("__outputs__");
popart::TensorId result = builder_->aiOnnxOpset11().matmul(inputs);
tensors_.emplace(outputs[0], result);
} else if (op_type == "ReduceMean") {
auto inputs = GetOpInputs(op);
auto outputs = op->Output("__outputs__");
@@ -338,6 +343,17 @@ void IpuBackend::LowerBody(const ir::Graph* graph) {
popart::TensorId result =
builder_->aiOnnxOpset11().reducemean(inputs, axes, keepdims);
tensors_.emplace(outputs[0], result);
} else if (op_type == "Softmax") {
auto inputs = GetOpInputs(op);
auto outputs = op->Output("__outputs__");
auto axis = BOOST_GET_CONST(int64_t, op->GetAttr("axis"));
popart::TensorId result = builder_->aiOnnxOpset11().softmax(inputs, axis);
tensors_.emplace(outputs[0], result);
} else if (op_type == "Sum") {
auto inputs = GetOpInputs(op);
auto outputs = op->Output("__outputs__");
popart::TensorId result = builder_->aiOnnxOpset11().sum(inputs);
tensors_.emplace(outputs[0], result);
} else if (op_type == "Pow") {
auto inputs = GetOpInputs(op);
auto outputs = op->Output("__outputs__");
@@ -365,13 +381,14 @@ void IpuBackend::LowerBody(const ir::Graph* graph) {
auto outputs = op->Output("__outputs__");
auto epsilon = BOOST_GET_CONST(float, op->GetAttr("epsilon"));
auto groups = BOOST_GET_CONST(int64_t, op->GetAttr("groups"));
std::vector<popart::TensorId> result =
    builder_->aiGraphcoreOpset1().groupnormalization(inputs, groups,
                                                     epsilon);
// Y
tensors_.emplace(outputs[0], result[0]);
// Mean
tensors_.emplace(outputs[1], result[1]);
// Variance
tensors_.emplace(outputs[2], result[2]);
} else if (op_type == "MaxPool") {
auto inputs = GetOpInputs(op);
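
For reference, each of the new lowerings above maps a canonicalized Paddle op onto a single ONNX opset-11 builder call. A minimal NumPy sketch of the underlying math (illustration only; the NumPy code below is an assumption of this write-up, not part of the commit):

import numpy as np

a = np.random.rand(10, 10).astype(np.float32)
b = np.random.rand(10, 20).astype(np.float32)

# "MatMul": a plain matrix product of the two inputs.
matmul_ref = a @ b

# "Sum": element-wise sum over a variadic list of same-shaped tensors.
sum_ref = a + a + a

# "Softmax": normalised exponentials along `axis`.
def softmax_ref(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))  # shift by max for numerical stability
    return e / e.sum(axis=axis, keepdims=True)

out = softmax_ref(a, axis=-1)
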
@@ -18,7 +18,6 @@ namespace paddle {
namespace framework {
namespace ipu {

// This avoids the static initialisation order fiasco,
std::unordered_map<std::string, SymbolHandler> &SymbolHandlers() {
static std::unordered_map<std::string, SymbolHandler> symbol_handlers;
@@ -80,6 +79,17 @@ void ConnectNodes(ir::Node *first_node, ir::Node *next_node) {
next_node->inputs.push_back(first_node);
}

void CopyOpAttr(std::string attr_name, OpDesc *op, OpDesc *new_op,
bool override) {
if (new_op->HasAttr(attr_name) && !override) {
return;
}
if (op->HasAttr(attr_name)) {
new_op->SetAttr(attr_name, op->GetAttr(attr_name));
new_op->Flush();
}
}

int ConvertDataType(int type) {
auto dtype = static_cast<proto::VarType::Type>(type);
switch (dtype) {
@@ -18,10 +18,10 @@
#include <string>
#include <vector>

#include "paddle/fluid/framework/ipu/ipu_utils.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ipu/ipu_utils.h"

namespace paddle {
namespace framework {
@@ -42,6 +42,8 @@ SymbolHandler GetHandler(const std::string &);
void MoveNodeInputs(ir::Node *node, ir::Node *new_node);
void MoveNodeOutputs(ir::Node *node, ir::Node *new_node);
void ConnectNodes(ir::Node *first_node, ir::Node *next_node);
void CopyOpAttr(std::string attr_name, OpDesc *op, OpDesc *new_op,
bool override = false);

int ConvertDataType(int);

54 changes: 54 additions & 0 deletions paddle/fluid/framework/ipu/popart_canonicalization/math_ops.cc
@@ -106,9 +106,63 @@ ir::Node *pow_handler(ir::Graph *graph, ir::Node *node) {
return node_pow;
}

ir::Node *mul_handler(ir::Graph *graph, ir::Node *node) {
auto *op = node->Op();
auto op_desc = std::make_unique<framework::OpDesc>();
op_desc->SetType("MatMul");

std::vector<std::string> inputs;
inputs.push_back(op->Input("X").front());
inputs.push_back(op->Input("Y").front());
op_desc->SetInput("__inputs__", inputs);
std::vector<std::string> outputs;
outputs.push_back(op->Output("Out").front());
op_desc->SetOutput("__outputs__", outputs);

op_desc->Flush();
return graph->CreateOpNode(op_desc.get());
}

ir::Node *sum_handler(ir::Graph *graph, ir::Node *node) {
auto *op = node->Op();
auto op_desc = std::make_unique<framework::OpDesc>();
op_desc->SetType("Sum");

op_desc->SetInput("__inputs__", op->Input("X"));
std::vector<std::string> outputs;
outputs.push_back(op->Output("Out").front());
op_desc->SetOutput("__outputs__", outputs);

op_desc->Flush();
return graph->CreateOpNode(op_desc.get());
}

ir::Node *softmax_handler(ir::Graph *graph, ir::Node *node) {
auto *op = node->Op();
auto op_desc = std::make_unique<framework::OpDesc>();
op_desc->SetType("Softmax");

std::vector<std::string> inputs;
inputs.push_back(op->Input("X").front());
op_desc->SetInput("__inputs__", inputs);
std::vector<std::string> outputs;
outputs.push_back(op->Output("Out").front());
op_desc->SetOutput("__outputs__", outputs);

auto axis_ = BOOST_GET_CONST(int, op->GetAttr("axis"));
auto axis = int64_t{axis_};
op_desc->SetAttr("axis", axis);

op_desc->Flush();
return graph->CreateOpNode(op_desc.get());
}

REGISTER_HANDLER(elementwise_add, elementwise_add_handler);
REGISTER_HANDLER(reduce_mean, reduce_mean_handler);
REGISTER_HANDLER(pow, pow_handler);
REGISTER_HANDLER(mul, mul_handler);
REGISTER_HANDLER(sum, sum_handler);
REGISTER_HANDLER(softmax, softmax_handler);

} // namespace
} // namespace ipu
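
The REGISTER_HANDLER entries above bind the Paddle op types mul, sum, and softmax to their canonicalization handlers. A small static-graph sketch of a program that would exercise all three, using the same fluid-era layers as the new unit tests below (a sketch, not part of the commit):

import paddle
import paddle.fluid
paddle.enable_static()

main_prog = paddle.static.Program()
with paddle.static.program_guard(main_prog):
    x = paddle.fluid.layers.fill_constant(shape=[4, 4], dtype='float32', value=1.0)
    y = paddle.fluid.layers.fill_constant(shape=[4, 8], dtype='float32', value=2.0)
    m = paddle.fluid.layers.mul(x, y)    # mul -> mul_handler -> "MatMul"
    s = paddle.fluid.layers.sum([x, x])  # sum -> sum_handler -> "Sum"
    p = paddle.fluid.layers.softmax(x)   # softmax -> softmax_handler -> "Softmax"
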
3 changes: 3 additions & 0 deletions paddle/fluid/framework/ir/ipu/popart_canonicalization_pass.cc
@@ -53,6 +53,9 @@ void PopartCanonicalizationPass::ApplyImpl(ir::Graph* graph) const {
if (new_node->outputs.empty()) {
ipu::MoveNodeOutputs(node, new_node);
}
if (!new_node->Op()->HasAttr("ipu_index")) {
ipu::CopyOpAttr("ipu_index", node->Op(), new_node->Op());
}
graph->RemoveNode(node);
}
}
73 changes: 73 additions & 0 deletions python/paddle/fluid/tests/unittests/ipu/test_mul_op.py
@@ -0,0 +1,73 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import unittest
import paddle
import paddle.fluid
import paddle.fluid.compiler as compiler

paddle.enable_static()
SEED = 2021


@unittest.skipIf(not paddle.is_compiled_with_ipu(),
"core is not compiled with IPU")
class TestMul(unittest.TestCase):
def _test_mul(self, run_ipu=True):
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
main_prog.random_seed = SEED
startup_prog.random_seed = SEED
np.random.seed(SEED)

with paddle.static.program_guard(main_prog, startup_prog):
a = paddle.fluid.layers.fill_constant(
shape=[10, 10], dtype='float32', value=0.5)
b = paddle.fluid.layers.fill_constant(
shape=[10, 20], dtype='float32', value=3.1415926)
out = paddle.fluid.layers.mul(a, b)

if run_ipu:
place = paddle.IPUPlace(0)
else:
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_prog)

if run_ipu:
feed_list = []
fetch_list = [out.name]
ipu_build_strategy = compiler.get_ipu_build_strategy()
ipu_build_strategy.is_training = False
program = compiler.IpuCompiler(
main_prog, ipu_build_strategy=ipu_build_strategy).compile(
feed_list, fetch_list)
else:
program = main_prog

result = exe.run(program, feed={}, fetch_list=[out])
return result[0]

def test_mul(self):
ipu_res = self._test_mul(True)
cpu_res = self._test_mul(False)

self.assertTrue(np.allclose(ipu_res, cpu_res, atol=1e-4))


if __name__ == "__main__":
unittest.main()
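
As a plain-NumPy sanity reference for this test (illustration only, not part of the commit): both inputs are constant, so every entry of the expected output is 10 * 0.5 * 3.1415926 ≈ 15.708.

import numpy as np

a = np.full((10, 10), 0.5, dtype=np.float32)
b = np.full((10, 20), 3.1415926, dtype=np.float32)
expected = a @ b  # each entry sums 10 products of 0.5 * 3.1415926
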
75 changes: 75 additions & 0 deletions python/paddle/fluid/tests/unittests/ipu/test_softmax_op.py
@@ -0,0 +1,75 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import unittest
import paddle
import paddle.fluid
import paddle.fluid.compiler as compiler

paddle.enable_static()
SEED = 2021


@unittest.skipIf(not paddle.is_compiled_with_ipu(),
"core is not compiled with IPU")
class TestSoftmax(unittest.TestCase):
def _test_softmax(self, run_ipu=True):
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
main_prog.random_seed = SEED
startup_prog.random_seed = SEED
np.random.seed(SEED)

np_a = np.random.rand(1, 3, 10, 10).astype(np.float32)
with paddle.static.program_guard(main_prog, startup_prog):
a = paddle.static.data(
name='a', shape=[1, 3, 10, 10], dtype='float32')
b = paddle.fluid.layers.softmax(a)
c = paddle.fluid.layers.softmax(b, axis=0)
out = paddle.fluid.layers.softmax(c, axis=1)
# NOTE: `out` is overridden below, so only the default-axis softmax (b)
# is compared against CPU; the axis=0 and axis=1 ops are built but
# their results are discarded.
out = b

if run_ipu:
place = paddle.IPUPlace(0)
else:
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_prog)

if run_ipu:
feed_list = [a.name]
fetch_list = [out.name]
ipu_build_strategy = compiler.get_ipu_build_strategy()
ipu_build_strategy.is_training = False
program = compiler.IpuCompiler(
main_prog, ipu_build_strategy=ipu_build_strategy).compile(
feed_list, fetch_list)
else:
program = main_prog

result = exe.run(program, feed={'a': np_a}, fetch_list=[out])
return result[0]

def test_softmax(self):
ipu_res = self._test_softmax(True)
cpu_res = self._test_softmax(False)

self.assertTrue(np.allclose(ipu_res, cpu_res, atol=1e-4))


if __name__ == "__main__":
unittest.main()
75 changes: 75 additions & 0 deletions python/paddle/fluid/tests/unittests/ipu/test_sum_op.py
@@ -0,0 +1,75 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import numpy as np
import unittest
import paddle
import paddle.fluid
import paddle.fluid.compiler as compiler

paddle.enable_static()
SEED = 2021


@unittest.skipIf(not paddle.is_compiled_with_ipu(),
"core is not compiled with IPU")
class TestSum(unittest.TestCase):
def _test_sum(self, run_ipu=True):
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()
main_prog.random_seed = SEED
startup_prog.random_seed = SEED
np.random.seed(SEED)

with paddle.static.program_guard(main_prog, startup_prog):
a = paddle.fluid.layers.fill_constant(
shape=[10, 10], dtype='float32', value=0.5)
b = paddle.fluid.layers.fill_constant(
shape=[10, 10], dtype='float32', value=1.0)
c = paddle.fluid.layers.fill_constant(
shape=[10, 10], dtype='float32', value=3.1415926)
out = paddle.fluid.layers.sum([a, b, c])

if run_ipu:
place = paddle.IPUPlace(0)
else:
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_prog)

if run_ipu:
feed_list = []
fetch_list = [out.name]
ipu_build_strategy = compiler.get_ipu_build_strategy()
ipu_build_strategy.is_training = False
program = compiler.IpuCompiler(
main_prog, ipu_build_strategy=ipu_build_strategy).compile(
feed_list, fetch_list)
else:
program = main_prog

result = exe.run(program, feed={}, fetch_list=[out])
return result[0]

def test_sum(self):
ipu_res = self._test_sum(True)
cpu_res = self._test_sum(False)

self.assertTrue(np.allclose(ipu_res, cpu_res, atol=1e-4))


if __name__ == "__main__":
unittest.main()
