modified error info for maxout op

update-install-command
jerrywgz 7 years ago
parent bc7503c85e
commit 85fe65ae61

@ -71,10 +71,9 @@ class MaxOutOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input tensor of MaxoutOp"
"should not be null.");
"Input(X) of MaxoutOpshould not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output tensor of MaxoutOp should not be null.");
"Output(Out) of MaxoutOp should not be null.");
auto in_x_dims = ctx->GetInputDim("X");
int groups = ctx->Attrs().Get<int>("groups");
// check groups > 1
@ -90,9 +89,10 @@ class MaxOutOpGrad : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input tensor must not be null.");
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of MaxOutOpGrad must not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Output gradient tensor should not be null.");
"Output(Grad@X) of MaxOutOpGrad should not be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
};

Loading…
Cancel
Save