|
|
|
@ -82,6 +82,14 @@ class MaxOutOp : public framework::OperatorWithKernel {
|
|
|
|
|
// check groups > 1
|
|
|
|
|
PADDLE_ENFORCE_GT(groups, 1,
|
|
|
|
|
"Attr(groups) of Op(maxout) should be larger than 1.");
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
in_x_dims[axis] % groups, 0,
|
|
|
|
|
"ValueError: The number of input channels for Op(maxout) "
|
|
|
|
|
"should be divisible by Attr(groups). But received: the "
|
|
|
|
|
"input's channels is [%d], the shape of input is [%s], "
|
|
|
|
|
"the Attr(groups) is [%d], the Attr(axis) is [%d]. The "
|
|
|
|
|
"error may come from wrong Attr(groups) or Attr(axis) setting.",
|
|
|
|
|
in_x_dims[axis], in_x_dims, groups, axis);
|
|
|
|
|
std::vector<int64_t> output_shape(
|
|
|
|
|
{in_x_dims[0], in_x_dims[1], in_x_dims[2], in_x_dims[3]});
|
|
|
|
|
output_shape[axis] = in_x_dims[axis] / groups;
|
|
|
|
|