|
|
|
@ -13,6 +13,7 @@
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/operators/reduce_op.h"
|
|
|
|
|
#include "paddle/operators/net_op.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
@ -159,6 +160,66 @@ class ReduceMinOpMaker : public ReduceOpMaker {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class NormOp : public NetOp {
|
|
|
|
|
public:
|
|
|
|
|
NormOp(const std::string &type, const framework::VariableNameMap &inputs,
|
|
|
|
|
const framework::VariableNameMap &outputs,
|
|
|
|
|
const framework::AttributeMap &attrs)
|
|
|
|
|
: NetOp(type, inputs, outputs, attrs) {
|
|
|
|
|
PADDLE_ENFORCE_NE(Input("X"), framework::kEmptyVarName,
|
|
|
|
|
"Input(X) of NormOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_NE(Output("AbsOut"), framework::kEmptyVarName,
|
|
|
|
|
"Output(AbsOut) of NormOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_NE(Output("PowOut"), framework::kEmptyVarName,
|
|
|
|
|
"Output(PowOut) of NormOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_NE(Output("SumOut"), framework::kEmptyVarName,
|
|
|
|
|
"Output(SumOut) of NormOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_NE(Output("Out"), framework::kEmptyVarName,
|
|
|
|
|
"Output(Out) of NormOp should not be null.");
|
|
|
|
|
auto dim = Attr<int>("dim");
|
|
|
|
|
auto keep_dim = Attr<bool>("keep_dim");
|
|
|
|
|
auto p = Attr<float>("p");
|
|
|
|
|
PADDLE_ENFORCE_GT(p, 0, "Order of the norm should be positive.");
|
|
|
|
|
AppendOp(framework::OpRegistry::CreateOp("abs", {{"X", {Input("X")}}},
|
|
|
|
|
{{"Y", {Output("AbsOut")}}}, {}));
|
|
|
|
|
AppendOp(framework::OpRegistry::CreateOp("pow", {{"X", {Output("AbsOut")}}},
|
|
|
|
|
{{"Y", {Output("PowOut")}}},
|
|
|
|
|
{{"factor", p}}));
|
|
|
|
|
framework::AttributeMap sum_attr;
|
|
|
|
|
sum_attr["dim"] = dim;
|
|
|
|
|
sum_attr["keep_dim"] = keep_dim;
|
|
|
|
|
AppendOp(framework::OpRegistry::CreateOp(
|
|
|
|
|
"reduce_sum", {{"X", {Output("PowOut")}}},
|
|
|
|
|
{{"Out", {Output("SumOut")}}}, sum_attr));
|
|
|
|
|
AppendOp(framework::OpRegistry::CreateOp(
|
|
|
|
|
"pow", {{"X", {Output("SumOut")}}}, {{"Y", {Output("Out")}}},
|
|
|
|
|
{{"factor", static_cast<float>(1. / p)}}));
|
|
|
|
|
CompleteAddOp(false);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class NormOpMaker : public ReduceOpMaker {
|
|
|
|
|
public:
|
|
|
|
|
NormOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
|
|
|
|
|
: ReduceOpMaker(proto, op_checker) {
|
|
|
|
|
AddOutput("AbsOut",
|
|
|
|
|
"(Tensor) The intermediate output of Norm operator, "
|
|
|
|
|
"saving the absolute value of the input tensor X.")
|
|
|
|
|
.AsIntermediate();
|
|
|
|
|
AddOutput("PowOut",
|
|
|
|
|
"(Tensor) The intermediate output of Norm operator, "
|
|
|
|
|
"saving the p-th power of the output tensor AbsOut.")
|
|
|
|
|
.AsIntermediate();
|
|
|
|
|
AddOutput("SumOut",
|
|
|
|
|
"(Tensor) the intermediate output of Norm operator, "
|
|
|
|
|
"saving the sum of PowOut reduced on the given dimension.")
|
|
|
|
|
.AsIntermediate();
|
|
|
|
|
AddAttr<float>("p", "(float, default 2) The order of Norm.").SetDefault(2);
|
|
|
|
|
SetComment("Norm", "vector p-norm");
|
|
|
|
|
AddComment(comment_);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
@ -176,6 +237,8 @@ REGISTER_OP(reduce_max, ops::ReduceOp, ops::ReduceMaxOpMaker, reduce_max_grad,
|
|
|
|
|
REGISTER_OP(reduce_min, ops::ReduceOp, ops::ReduceMinOpMaker, reduce_min_grad,
|
|
|
|
|
ops::ReduceGradOp);
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_WITHOUT_GRADIENT(norm, ops::NormOp, ops::NormOpMaker);
|
|
|
|
|
|
|
|
|
|
#define REGISTER_REDUCE_CPU_KERNEL(reduce_type, functor, grad_functor) \
|
|
|
|
|
REGISTER_OP_CPU_KERNEL( \
|
|
|
|
|
reduce_type, \
|
|
|
|
|