|
|
|
@ -24,40 +24,48 @@ class PReluOp : public framework::OperatorWithKernel {
|
|
|
|
|
: OperatorWithKernel(type, inputs, outputs, attrs) {}
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
std::string mode = ctx->Attrs().Get<std::string>("mode");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "prelu");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("Alpha"), "Input", "Alpha", "prelu");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "prelu");
|
|
|
|
|
|
|
|
|
|
auto x_dim = ctx->GetInputDim("X");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
|
|
|
|
"Input(X) of PreluOp should not be null");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Alpha"),
|
|
|
|
|
"Input(Alpha) of PreluOp should not be null");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"Output(Out) of PreluOp should not be null");
|
|
|
|
|
std::string mode = ctx->Attrs().Get<std::string>("mode");
|
|
|
|
|
if (mode == "all") {
|
|
|
|
|
PADDLE_ENFORCE(product(ctx->GetInputDim("Alpha")) == 1,
|
|
|
|
|
"For mode 'all', size of weight Alpha must be one.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
product(ctx->GetInputDim("Alpha")), 1,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"For mode 'all', size of weight Alpha must be one."));
|
|
|
|
|
} else if (mode == "channel") {
|
|
|
|
|
PADDLE_ENFORCE(product(ctx->GetInputDim("Alpha")) == x_dim[1],
|
|
|
|
|
"For channel-wise mode, size of weight Alpha must be "
|
|
|
|
|
"equal to the number of channels, should be %d",
|
|
|
|
|
x_dim[1]);
|
|
|
|
|
PADDLE_ENFORCE_EQ(product(ctx->GetInputDim("Alpha")), x_dim[1],
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"For mode 'channel', size of weight Alpha must be "
|
|
|
|
|
"equal to the number of channels of input(x). But "
|
|
|
|
|
"recevied alpha's size: %d, x_dim[1]: %d",
|
|
|
|
|
product(ctx->GetInputDim("Alpha")), x_dim[1]));
|
|
|
|
|
} else if (mode == "element") {
|
|
|
|
|
auto alpha_dim = ctx->GetInputDim("Alpha");
|
|
|
|
|
auto alpha_rank = alpha_dim.size();
|
|
|
|
|
auto x_rank = x_dim.size();
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
alpha_rank, x_rank,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"For mode 'element', rank of weight Alpha must be ",
|
|
|
|
|
"equal to the rank of input(x). But recevied alpha's rank: %d, "
|
|
|
|
|
"x's rank: %d.",
|
|
|
|
|
alpha_rank, x_rank));
|
|
|
|
|
size_t x_product = 1;
|
|
|
|
|
size_t alpha_product = 1;
|
|
|
|
|
PADDLE_ENFORCE_EQ(alpha_rank, x_rank,
|
|
|
|
|
"For element-wise mode, rank of weight Alpha must be ",
|
|
|
|
|
"equal to the rank of input.");
|
|
|
|
|
for (int64_t i = x_rank - 1; i > 0; i--) {
|
|
|
|
|
x_product *= x_dim[i];
|
|
|
|
|
alpha_product *= alpha_dim[i];
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE_EQ(x_product, alpha_product,
|
|
|
|
|
"For element-wise mode, size of weight Alpha must be "
|
|
|
|
|
"equal to the number of input.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
alpha_product, x_product,
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"For mode 'element', the size of weight Alpha must be "
|
|
|
|
|
"equal to the size of input(x). But recevied alpha's size: %d, "
|
|
|
|
|
"x's size: %d.",
|
|
|
|
|
alpha_product, x_product));
|
|
|
|
|
} else {
|
|
|
|
|
PADDLE_THROW("Unkown mode %s", mode);
|
|
|
|
|
}
|
|
|
|
@ -108,9 +116,10 @@ class PReluGradOp : public framework::OperatorWithKernel {
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
|
|
|
|
|
|
void InferShape(framework::InferShapeContext *ctx) const override {
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
|
|
|
|
|
"Input(Out@GRAD) should not be null");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "prelu");
|
|
|
|
|
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
|
|
|
|
|
"Out@GRAD", "prelu");
|
|
|
|
|
|
|
|
|
|
auto x_grad_name = framework::GradVarName("X");
|
|
|
|
|
auto alpha_grad_name = framework::GradVarName("Alpha");
|
|
|
|
|
|
|
|
|
|