|
|
|
@ -27,7 +27,7 @@ class ClipByNormOp : public framework::OperatorWithKernel {
|
|
|
|
|
"Input(X) of ClipByNormOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"Output(Out) of ClipByNormOp should not be null.");
|
|
|
|
|
auto max_norm = Attr<float>("max_norm");
|
|
|
|
|
auto max_norm = ctx->Attrs().Get<float>("max_norm");
|
|
|
|
|
PADDLE_ENFORCE_GT(max_norm, 0, "max_norm should be greater than 0.");
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
ctx->SetOutputDim("Out", x_dims);
|
|
|
|
@ -35,7 +35,6 @@ class ClipByNormOp : public framework::OperatorWithKernel {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename AttrType>
|
|
|
|
|
class ClipByNormOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
public:
|
|
|
|
|
ClipByNormOpMaker(framework::OpProto* proto,
|
|
|
|
@ -46,7 +45,7 @@ class ClipByNormOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
"The number of dimensions must be between [1, 9].");
|
|
|
|
|
AddOutput("Out",
|
|
|
|
|
"(Tensor) The output of clip_by_norm op with shape as input(X)");
|
|
|
|
|
AddAttr<AttrType>("max_norm", "(float) The maximum norm value.");
|
|
|
|
|
AddAttr<float>("max_norm", "(float) The maximum norm value.");
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
ClipByNorm operator limits the L2 norm of the input 'X' within 'max_norm'.
|
|
|
|
|
If the L2 norm of 'X' is less than or equal to 'max_norm', 'Out' will be
|
|
|
|
@ -66,6 +65,6 @@ where norm('X') represents the L2 norm of 'X'.
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
REGISTER_OP_WITHOUT_GRADIENT(clip_by_norm, ops::ClipByNormOp,
|
|
|
|
|
ops::ClipByNormOpMaker<float>);
|
|
|
|
|
ops::ClipByNormOpMaker);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
clip_by_norm, ops::ClipByNormKernel<paddle::platform::CPUPlace, float>);
|
|
|
|
|