add check for input channels and Attr(groups), test=develop (#21095)

custom_op_abi
Zhang Ting 6 years ago committed by Aurelius84
parent dcf371b685
commit e0285eae64

@ -82,6 +82,14 @@ class MaxOutOp : public framework::OperatorWithKernel {
// check groups > 1
PADDLE_ENFORCE_GT(groups, 1,
"Attr(groups) of Op(maxout) should be larger than 1.");
PADDLE_ENFORCE_EQ(
in_x_dims[axis] % groups, 0,
"ValueError: The number of input channels for Op(maxout) "
"should be divisible by Attr(groups). But received: the "
"input's channels is [%d], the shape of input is [%s], "
"the Attr(groups) is [%d], the Attr(axis) is [%d]. The "
"error may come from wrong Attr(groups) or Attr(axis) setting.",
in_x_dims[axis], in_x_dims, groups, axis);
std::vector<int64_t> output_shape(
{in_x_dims[0], in_x_dims[1], in_x_dims[2], in_x_dims[3]});
output_shape[axis] = in_x_dims[axis] / groups;

Loading…
Cancel
Save