|
|
|
@ -23,14 +23,21 @@ class ClipOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
|
|
|
|
"Input(X) of ClipOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"Output(Out) of ClipOp should not be null.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(X) of ClipOp should not be null. Please check "
|
|
|
|
|
"if it is created correctly."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Output(Out) of ClipOp should not be null. Please "
|
|
|
|
|
"check if it is created correctly."));
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
auto max = ctx->Attrs().Get<float>("max");
|
|
|
|
|
auto min = ctx->Attrs().Get<float>("min");
|
|
|
|
|
PADDLE_ENFORCE_LT(min, max, "max should be greater than min.");
|
|
|
|
|
PADDLE_ENFORCE_LT(min, max, platform::errors::InvalidArgument(
|
|
|
|
|
"Max of ClipOp should be greater than min. "
|
|
|
|
|
"Received max is %f, received min is %f.",
|
|
|
|
|
max, min));
|
|
|
|
|
ctx->SetOutputDim("Out", x_dims);
|
|
|
|
|
ctx->ShareLoD("X", /*->*/ "Out");
|
|
|
|
|
}
|
|
|
|
@ -52,7 +59,7 @@ class ClipOpMaker : public framework::OpProtoAndCheckerMaker {
|
|
|
|
|
AddComment(R"DOC(
|
|
|
|
|
Clip Operator.
|
|
|
|
|
|
|
|
|
|
The clip operator limits the value of given input within an interval [min, max],
|
|
|
|
|
The clip operator limits the value of given input within an interval [min, max],
|
|
|
|
|
just as the following equation,
|
|
|
|
|
|
|
|
|
|
$$
|
|
|
|
@ -68,9 +75,14 @@ class ClipOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext* ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
|
|
|
|
|
"Input(Out@GRAD) should not be null");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
ctx->HasInput("X"), true,
|
|
|
|
|
platform::errors::InvalidArgument("Input(X) should not be null. Please "
|
|
|
|
|
"check if it is created correctly."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Out")), true,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input(Out@GRAD) should not be null. Please check if "
|
|
|
|
|
"it is created correctly."));
|
|
|
|
|
auto x_dims = ctx->GetInputDim("X");
|
|
|
|
|
if (ctx->HasOutput(framework::GradVarName("X"))) {
|
|
|
|
|
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
|
|
|
|
|