mobile_baidu
wwhu 7 years ago
parent c8c4b6e427
commit b3a86b6dbb

@@ -27,7 +27,7 @@ class ClipByNormOp : public framework::OperatorWithKernel {
                    "Input(X) of ClipByNormOp should not be null.");
     PADDLE_ENFORCE(ctx->HasOutput("Out"),
                    "Output(Out) of ClipByNormOp should not be null.");
-    auto max_norm = Attr<float>("max_norm");
+    auto max_norm = ctx->Attrs().Get<float>("max_norm");
     PADDLE_ENFORCE_GT(max_norm, 0, "max_norm should be greater than 0.");
     auto x_dims = ctx->GetInputDim("X");
     ctx->SetOutputDim("Out", x_dims);
@@ -35,7 +35,6 @@ class ClipByNormOp : public framework::OperatorWithKernel {
   }
 };
 
-template <typename AttrType>
 class ClipByNormOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   ClipByNormOpMaker(framework::OpProto* proto,
@@ -46,7 +45,7 @@ class ClipByNormOpMaker : public framework::OpProtoAndCheckerMaker {
              "The number of dimensions must be between [1, 9].");
     AddOutput("Out",
               "(Tensor) The output of clip_by_norm op with shape as input(X)");
-    AddAttr<AttrType>("max_norm", "(float) The maximum norm value.");
+    AddAttr<float>("max_norm", "(float) The maximum norm value.");
     AddComment(R"DOC(
 ClipByNorm operator limits the L2 norm of the input 'X' within 'max_norm'.
 If the L2 norm of 'X' is less than or equal to 'max_norm', 'Out' will be
@@ -66,6 +65,6 @@ where norm('X') represents the L2 norm of 'X'.
 namespace ops = paddle::operators;
 REGISTER_OP_WITHOUT_GRADIENT(clip_by_norm, ops::ClipByNormOp,
-                             ops::ClipByNormOpMaker<float>);
+                             ops::ClipByNormOpMaker);
 REGISTER_OP_CPU_KERNEL(
     clip_by_norm, ops::ClipByNormKernel<paddle::platform::CPUPlace, float>);
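For reference, the computation the operator's DOC string describes (Out = max_norm * X / norm(X) when norm(X) exceeds max_norm, otherwise Out = X) can be sketched as a standalone C++ function. This is a minimal illustrative sketch of the formula only, not the Paddle kernel; the function name and std::vector representation are assumptions for the example:

    #include <cmath>
    #include <vector>

    // Hypothetical sketch: scale a vector so its L2 norm does not exceed
    // max_norm, matching the formula in the op's DOC comment.
    std::vector<float> ClipByNorm(const std::vector<float>& x, float max_norm) {
      float sq_sum = 0.0f;
      for (float v : x) sq_sum += v * v;
      float norm = std::sqrt(sq_sum);
      std::vector<float> out(x);
      if (norm > max_norm) {
        float scale = max_norm / norm;  // Out = max_norm * X / norm(X)
        for (float& v : out) v *= scale;
      }
      return out;  // unchanged when norm(X) <= max_norm
    }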
