|
|
|
@ -39,15 +39,14 @@ template <typename AttrType>
|
|
|
|
|
class ClipByNormOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
ClipByNormOpMaker(framework::OpProto* proto,
|
|
|
|
|
framework::OpAttrChecker* op_checker)
|
|
|
|
|
framework::OpAttrChecker* op_checker)
|
|
|
|
|
: OpProtoAndCheckerMaker(proto, op_checker) {
|
|
|
|
|
AddInput("X",
|
|
|
|
|
"(Tensor)The input of clip_by_norm op."
|
|
|
|
|
"(Tensor) The input of clip_by_norm op."
|
|
|
|
|
"The number of dimensions must be between [1, 9].");
|
|
|
|
|
AddOutput("Out",
|
|
|
|
|
"(Tensor)The output of clip_by_norm op with shape as input(X)");
|
|
|
|
|
AddAttr<AttrType>(
|
|
|
|
|
"max_norm", "(float)The maximum norm value.");
|
|
|
|
|
"(Tensor) The output of clip_by_norm op with shape as input(X)");
|
|
|
|
|
AddAttr<AttrType>("max_norm", "(float)The maximum norm value.");
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
ClipByNorm operator limits the L2 norm of the input 'X' within 'max_norm'.
|
|
|
|
|
If the L2 norm of 'X' is less than or equal to 'max_norm', 'Out' will be
|
|
|
|
@ -62,29 +61,11 @@ where norm('X') represents the L2 norm of 'X'.
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class ClipByNormOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
|
|
|
|
|
"Input(Out@GRAD) should not be null");
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
if (ctx->HasOutput(framework::GradVarName("X"))) {
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
REGISTER_OP_WITHOUT_GRADIENT(clip_by_norm,
|
|
|
|
|
ops::ClipByNormOp,
|
|
|
|
|
REGISTER_OP_WITHOUT_GRADIENT(clip_by_norm, ops::ClipByNormOp,
|
|
|
|
|
ops::ClipByNormOpMaker<float>);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(clip_by_norm,
|
|
|
|
|
ops::ClipByNormKernel
|
|
|
|
|
<paddle::platform::CPUPlace, float>);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
clip_by_norm, ops::ClipByNormKernel<paddle::platform::CPUPlace, float>);
|
|
|
|
|