Merge pull request #14805 from mozga-intel/mozga-intel/element_wise_operator_ngraph
Enable element_wise_add operator for the nGraph engine
commit
3759c1db8c
@ -0,0 +1,87 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <memory>
#include <numeric>
#include <string>
#include <unordered_map>
#include <vector>

#include "ngraph/ngraph.hpp"
#include "paddle/fluid/operators/ngraph/ops/elementwise_node.h"
#include "paddle/fluid/platform/ngraph_helper.h"

namespace paddle {
namespace operators {
namespace ngraphs {

void BuildElementwiseAddNode(
    const std::shared_ptr<paddle::framework::OperatorBase>& op,
    std::shared_ptr<
        std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
        ngb_node_map) {
  BuildElementwiseBinaryNode<ngraph::op::Add>(op, ngb_node_map);
}

void BuildElementwiseAddGradNode(
    const std::shared_ptr<paddle::framework::OperatorBase>& op,
    std::shared_ptr<
        std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
        ngb_node_map) {
  auto op_attrs = paddle::framework::AttrReader(op->Attrs());
  int axis = op_attrs.Get<int>("axis");

  auto dout = paddle::platform::GetInputNode(op, "Out@GRAD", ngb_node_map);
  auto y = paddle::platform::GetInputNode(op, "Y", ngb_node_map);
  auto dout_shape = dout->get_shape();
  auto y_shape = y->get_shape();

  if (dout_shape == y_shape) {
    // No broadcasting happened in the forward pass, so both gradients
    // are simply dout.
    paddle::platform::SetOutputNode(op, "X@GRAD", dout, ngb_node_map);
    paddle::platform::SetOutputNode(op, "Y@GRAD", dout, ngb_node_map);
  } else {
    // Y was broadcast over X in the forward pass: dX stays dout, while
    // dY is dout summed over the broadcast axes.
    axis = (axis == -1 ? dout_shape.size() - y_shape.size() : axis);
    paddle::platform::TrimTrailingSingularDims(&y_shape);
    axis = (y_shape.size() == 0 ? dout_shape.size() : axis);

    int pre, n, post;
    paddle::platform::GetMidDims(dout_shape, y_shape, axis, &pre, &n, &post);

    // Collapse dout into the canonical {pre, n[, post]} view.
    ngraph::Shape lhs_shape{};
    lhs_shape.push_back(pre);
    lhs_shape.push_back(n);
    if (post != 1) {
      lhs_shape.push_back(post);
    }

    std::vector<size_t> lhs_order(dout_shape.size());
    std::iota(std::begin(lhs_order), std::end(lhs_order), 0);
    auto dout_reshape = std::make_shared<ngraph::op::Reshape>(
        dout, ngraph::AxisVector(lhs_order), lhs_shape);

    // Reduce over the broadcast dimensions: axis 0 (pre), and axis 2
    // (post) when present.
    ngraph::AxisSet axis_set{0};
    if (post != 1) {
      axis_set.insert(2);
    }

    auto dout_sum = std::make_shared<ngraph::op::Sum>(dout_reshape, axis_set);
    auto dy = std::make_shared<ngraph::op::Reshape>(
        dout_sum, ngraph::AxisVector{0}, y->get_shape());

    paddle::platform::SetOutputNode(op, "X@GRAD", dout, ngb_node_map);
    paddle::platform::SetOutputNode(op, "Y@GRAD", dy, ngb_node_map);
  }
}
}  // namespace ngraphs
}  // namespace operators
}  // namespace paddle
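Both the gradient path above and the forward preparation in the next file lean on paddle::platform::GetMidDims to factor the larger shape into a {pre, n, post} triple around the broadcast region. A minimal standalone sketch of the semantics this code assumes (the real helper lives in ngraph_helper.h; this version is illustrative only):

#include <cassert>
#include <cstddef>
#include <vector>

// Illustrative sketch: split lhs_shape into {pre, n, post} around the
// rhs_shape block that starts at `axis`:
//   pre  = product of lhs dims before axis
//   n    = product of the rhs dims (which must match lhs dims from axis on)
//   post = product of lhs dims after the rhs block
void GetMidDimsSketch(const std::vector<size_t>& lhs_shape,
                      const std::vector<size_t>& rhs_shape, int axis,
                      int* pre, int* n, int* post) {
  *pre = *n = *post = 1;
  for (int i = 0; i < axis; ++i) *pre *= lhs_shape[i];
  for (size_t i = 0; i < rhs_shape.size(); ++i) {
    assert(lhs_shape[axis + i] == rhs_shape[i]);  // dims must line up
    *n *= rhs_shape[i];
  }
  for (size_t i = axis + rhs_shape.size(); i < lhs_shape.size(); ++i)
    *post *= lhs_shape[i];
}

// e.g. lhs_shape {2, 3, 4, 5}, rhs_shape {3, 4}, axis 1
//      -> pre = 2, n = 12, post = 5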
paddle/fluid/operators/ngraph/ops/elementwise_binary_prepare_node.h

@ -0,0 +1,76 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <memory>
#include <numeric>
#include <string>
#include <unordered_map>
#include <vector>

#include "ngraph/ngraph.hpp"
#include "paddle/fluid/platform/ngraph_helper.h"

namespace paddle {
namespace operators {
namespace ngraphs {

// Fetches X and Y and, when their shapes differ, reshapes/broadcasts Y so
// that both operands have identical shapes for an element-wise binary op.
ngraph::NodeVector ElementwiseBinaryNodePrepare(
    const std::shared_ptr<paddle::framework::OperatorBase>& op,
    std::shared_ptr<
        std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
        ngb_node_map) {
  auto op_attrs = paddle::framework::AttrReader(op->Attrs());
  int axis = op_attrs.Get<int>("axis");
  auto lhs = paddle::platform::GetInputNode(op, "X", ngb_node_map);
  auto rhs = paddle::platform::GetInputNode(op, "Y", ngb_node_map);

  auto lhs_shape = lhs->get_shape();
  auto rhs_shape = rhs->get_shape();

  PADDLE_ENFORCE_GE(lhs_shape.size(), rhs_shape.size(),
                    "Rank of first input must be >= rank of second input.");
  if (lhs_shape == rhs_shape) {
    return ngraph::NodeVector{lhs, rhs};
  }
  axis = (axis == -1 ? lhs_shape.size() - rhs_shape.size() : axis);
  PADDLE_ENFORCE(axis >= 0 && axis < static_cast<int>(lhs_shape.size()),
                 "Axis should be in range [0, lhs_shape.size())");
  paddle::platform::TrimTrailingSingularDims(&rhs_shape);
  axis = (rhs_shape.size() == 0) ? lhs_shape.size() : axis;

  int pre, n, post;
  paddle::platform::GetMidDims(lhs_shape, rhs_shape, axis, &pre, &n, &post);

  // Canonical {pre, n, post} view of the lhs shape.
  ngraph::Shape l_shape{};
  l_shape.push_back(pre);
  l_shape.push_back(n);
  l_shape.push_back(post);

  // Flatten rhs to {n}, broadcast it to {pre, n, post}, then reshape the
  // result to the full lhs shape so the operands align element-wise.
  std::vector<size_t> rhs_order(rhs->get_shape().size());
  std::iota(std::begin(rhs_order), std::end(rhs_order), 0);
  ngraph::Shape r_shape{};
  r_shape.push_back(n);
  auto rhs_reshape = std::make_shared<ngraph::op::Reshape>(
      rhs, ngraph::AxisVector(rhs_order), r_shape);
  auto rhs_bcast = std::make_shared<ngraph::op::Broadcast>(
      rhs_reshape, l_shape, ngraph::AxisSet{0, 2});
  std::vector<size_t> bcast_order(rhs_bcast->get_shape().size());
  std::iota(std::begin(bcast_order), std::end(bcast_order), 0);
  std::shared_ptr<ngraph::Node> rhs_bcast_reshape =
      std::make_shared<ngraph::op::Reshape>(
          rhs_bcast, ngraph::AxisVector(bcast_order), lhs_shape);
  return ngraph::NodeVector{lhs, rhs_bcast_reshape};
}
}  // namespace ngraphs
}  // namespace operators
}  // namespace paddle
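To make the reshape → broadcast → reshape pipeline concrete, here is a hedged standalone reproduction of the alignment for lhs shape {2, 3, 4, 5} and rhs shape {3, 4} at axis = 1 (so pre = 2, n = 12, post = 5). The function name is hypothetical; the node constructors are the same legacy nGraph ones used above:

#include <memory>

#include "ngraph/ngraph.hpp"

// Hypothetical standalone example of ElementwiseBinaryNodePrepare's
// alignment for one concrete case.
std::shared_ptr<ngraph::Node> PrepareRhsExample(
    std::shared_ptr<ngraph::Node> rhs /* shape {3, 4} */) {
  // 1. Flatten rhs to {n} = {12}.
  auto flat = std::make_shared<ngraph::op::Reshape>(
      rhs, ngraph::AxisVector{0, 1}, ngraph::Shape{12});
  // 2. Broadcast to {pre, n, post} = {2, 12, 5} along axes {0, 2}.
  auto bcast = std::make_shared<ngraph::op::Broadcast>(
      flat, ngraph::Shape{2, 12, 5}, ngraph::AxisSet{0, 2});
  // 3. Restore the full lhs shape {2, 3, 4, 5} so the two operands can
  //    be combined element-wise.
  return std::make_shared<ngraph::op::Reshape>(
      bcast, ngraph::AxisVector{0, 1, 2}, ngraph::Shape{2, 3, 4, 5});
}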
paddle/fluid/operators/ngraph/ops/elementwise_node.h

@ -0,0 +1,63 @@
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <memory>
#include <string>
#include <unordered_map>

#include "ngraph/ngraph.hpp"
#include "paddle/fluid/operators/ngraph/ops/elementwise_binary_prepare_node.h"
#include "paddle/fluid/platform/ngraph_helper.h"

namespace paddle {
namespace operators {
namespace ngraphs {

template <typename T>
void BuildElementwiseBinaryNode(
    const std::shared_ptr<paddle::framework::OperatorBase>& op,
    std::shared_ptr<
        std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
        ngb_node_map) {
  auto nodes = ElementwiseBinaryNodePrepare(op, ngb_node_map);
  std::shared_ptr<ngraph::Node>& x = nodes.at(0);
  std::shared_ptr<ngraph::Node>& y = nodes.at(1);

  // nGraph binary ops require matching element types; convert y to x's type.
  if (x->get_element_type() != y->get_element_type()) {
    y = std::make_shared<ngraph::op::Convert>(y, x->get_element_type());
  }
  auto out = std::make_shared<T>(x, y);
  paddle::platform::SetOutputNode(op, "Out", out, ngb_node_map);
}

template <typename T>
void BuildElementwiseCompareNode(
    const std::shared_ptr<paddle::framework::OperatorBase>& op,
    std::shared_ptr<
        std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
        ngb_node_map) {
  auto nodes = ElementwiseBinaryNodePrepare(op, ngb_node_map);
  std::shared_ptr<ngraph::Node>& x = nodes.at(0);
  std::shared_ptr<ngraph::Node>& y = nodes.at(1);

  // For comparisons, promote both operands to f64 so neither side loses
  // precision before the compare.
  if (x->get_element_type() != y->get_element_type()) {
    x = std::make_shared<ngraph::op::Convert>(x, ngraph::element::f64);
    y = std::make_shared<ngraph::op::Convert>(y, ngraph::element::f64);
  }
  auto out = std::make_shared<T>(x, y);
  paddle::platform::SetOutputNode(op, "Out", out, ngb_node_map);
}
}  // namespace ngraphs
}  // namespace operators
}  // namespace paddle
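Since both builders are templated on the nGraph op type, enabling further element-wise operators should mostly reduce to one-line wrappers. A hypothetical sketch, not part of this PR (ngraph::op::Multiply and ngraph::op::Greater are existing legacy nGraph ops, but these Paddle wrappers are assumptions):

// Hypothetical follow-up wrappers in the style of BuildElementwiseAddNode.
void BuildElementwiseMulNode(
    const std::shared_ptr<paddle::framework::OperatorBase>& op,
    std::shared_ptr<
        std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
        ngb_node_map) {
  BuildElementwiseBinaryNode<ngraph::op::Multiply>(op, ngb_node_map);
}

void BuildGreaterThanNode(
    const std::shared_ptr<paddle::framework::OperatorBase>& op,
    std::shared_ptr<
        std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
        ngb_node_map) {
  BuildElementwiseCompareNode<ngraph::op::Greater>(op, ngb_node_map);
}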
@ -0,0 +1,87 @@
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import unittest
from paddle.fluid.tests.unittests.test_elementwise_add_op import *


# Each subclass reruns an existing elementwise_add test case so the same
# inputs and expected outputs are exercised on the nGraph engine.
class TestNGRAPHElementwiseAddOp(TestElementwiseAddOp):
    def init_input_output(self):
        super(TestNGRAPHElementwiseAddOp, self).init_input_output()


class TestNGRAPHElementwiseAddOp_scalar(TestElementwiseAddOp_scalar):
    def init_input_output(self):
        super(TestNGRAPHElementwiseAddOp_scalar, self).init_input_output()


class TestNGRAPHElementwiseAddOp_scalar2(TestElementwiseAddOp_scalar2):
    def init_input_output(self):
        super(TestNGRAPHElementwiseAddOp_scalar2, self).init_input_output()


class TestNGRAPHElementwiseAddOp_Vector(TestElementwiseAddOp_Vector):
    def init_input_output(self):
        super(TestNGRAPHElementwiseAddOp_Vector, self).init_input_output()


class TestNGRAPHElementwiseAddOp_broadcast_0(TestElementwiseAddOp_broadcast_0):
    def init_input_output(self):
        super(TestNGRAPHElementwiseAddOp_broadcast_0, self).init_input_output()


class TestNGRAPHElementwiseAddOp_broadcast_1(TestElementwiseAddOp_broadcast_1):
    def init_input_output(self):
        super(TestNGRAPHElementwiseAddOp_broadcast_1, self).init_input_output()


class TestNGRAPHElementwiseAddOp_broadcast_2(TestElementwiseAddOp_broadcast_2):
    def init_input_output(self):
        super(TestNGRAPHElementwiseAddOp_broadcast_2, self).init_input_output()


class TestNGRAPHElementwiseAddOp_broadcast_3(TestElementwiseAddOp_broadcast_3):
    def init_input_output(self):
        super(TestNGRAPHElementwiseAddOp_broadcast_3, self).init_input_output()


class TestNGRAPHElementwiseAddOp_broadcast_4(TestElementwiseAddOp_broadcast_4):
    def init_input_output(self):
        super(TestNGRAPHElementwiseAddOp_broadcast_4, self).init_input_output()


class TestNGRAPHElementwiseAddOp_rowwise_add_0(
        TestElementwiseAddOp_rowwise_add_0):
    def init_input_output(self):
        super(TestNGRAPHElementwiseAddOp_rowwise_add_0,
              self).init_input_output()


class TestNGRAPHElementwiseAddOp_rowwise_add_1(
        TestElementwiseAddOp_rowwise_add_1):
    def init_input_output(self):
        super(TestNGRAPHElementwiseAddOp_rowwise_add_1,
              self).init_input_output()


class TestNGRAPHElementwiseAddOp_channelwise_add(
        TestElementwiseAddOp_channelwise_add):
    def init_input_output(self):
        super(TestNGRAPHElementwiseAddOp_channelwise_add,
              self).init_input_output()


if __name__ == '__main__':
    unittest.main()