|
|
|
@ -26,10 +26,13 @@ class PReluOp : public framework::OperatorWithKernel {
|
|
|
|
|
std::string mode = ctx->Attrs().Get<std::string>("mode");
|
|
|
|
|
|
|
|
|
|
auto x_dim = ctx->GetInputDim("X");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should not be null");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Alpha"), "Input(Alpha) should not be null");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("X"),
|
|
|
|
|
"Input(X) of PreluOp should not be null");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasInput("Alpha"),
|
|
|
|
|
"Input(Alpha) of PreluOp should not be null");
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"), "Output(Out) should not be null");
|
|
|
|
|
PADDLE_ENFORCE(ctx->HasOutput("Out"),
|
|
|
|
|
"Output(Out) of PreluOp should not be null");
|
|
|
|
|
if (mode == "all") {
|
|
|
|
|
PADDLE_ENFORCE(product(ctx->GetInputDim("Alpha")) == 1,
|
|
|
|
|
"For mode 'all', size of weight Alpha must be one.");
|
|
|
|
|