Enable square operator for the nGraph Bridge. (#17551)

test=develop
fix_ema
Krzysztof Binias 7 years ago committed by Tao Luo
parent ff5fdc0b67
commit 43d15b9d96

@ -37,6 +37,16 @@ void BuildReluGradNode(
platform::SetOutputNode(op, "X@GRAD", relu_grad, ngb_node_map);
}
void BuildSquareNode(
const std::shared_ptr<framework::OperatorBase>& op,
std::shared_ptr<
std::unordered_map<std::string, std::shared_ptr<ngraph::Node>>>
ngb_node_map) {
auto input = platform::GetInputNode(op, "X", ngb_node_map);
auto out = input * input;
platform::SetOutputNode(op, "Out", out, ngb_node_map);
}
void BuildTanhGradNode(
const std::shared_ptr<framework::OperatorBase>& op,
std::shared_ptr<
@ -55,4 +65,5 @@ void BuildTanhGradNode(
} // namespace paddle
REGISTER_NG_OP(relu_grad, BuildReluGradNode);
REGISTER_NG_OP(square, BuildSquareNode);
REGISTER_NG_OP(tanh_grad, BuildTanhGradNode);

@ -18,7 +18,7 @@ import unittest
import numpy as np
import paddle.fluid.core as core
from paddle.fluid.tests.unittests.op_test import OpTest
from paddle.fluid.tests.unittests.test_activation_op import TestAbs, TestSigmoid, TestRelu, TestTanh
from paddle.fluid.tests.unittests.test_activation_op import TestAbs, TestSigmoid, TestSquare, TestRelu, TestTanh
class TestNGRAPHReluDim4(TestRelu):

Loading…
Cancel
Save