Paddle/paddle/fluid/operators/maxout_op.cc


/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */
#include "paddle/fluid/operators/maxout_op.h"
#include <vector>
namespace paddle {
namespace operators {
using framework::Tensor;
class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X",
             "A 4-D Tensor with data type of float32 or float64. "
             "The data format is NCHW or NHWC, where N is the "
             "batch size, C is the number of channels, and "
             "H and W are the height and width of the "
             "feature. ");
    AddOutput("Out",
              "A 4-D Tensor with the same data type and data format "
              "as the input Tensor. ");
    AddAttr<int>(
        "groups",
        "Specifies how many groups the input tensor will be split into "
        "along the channel dimension. The number of output channels is "
        "the number of input channels divided by groups. ");
    AddAttr<int>(
        "axis",
        "Specifies the index of the channel dimension on which maxout "
        "will be performed. It should be 1 when the data format is NCHW, "
        "and -1 or 3 when the data format is NHWC. "
        "Default: 1. ")
        .SetDefault(1);
AddComment(R"DOC(
MaxOut Operator.
Assume the input shape is (N, Ci, H, W).
The output shape is (N, Co, H, W).
Then $Co = Ci / groups$ and the operator formula is as follows:
$$ y_{si+j} = \max_{k} x_{gsi + sk + j} $$
$$ g = groups $$
$$ s = \frac{input.size}{num\_channels} $$
$$ 0 \le i < \frac{num\_channels}{groups} $$
$$ 0 \le j < s $$
$$ 0 \le k < groups $$
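For example, with groups = 2 and an input of shape (N, 6, H, W) in NCHW
format, the output has shape (N, 3, H, W): output channel $i$ is the
elementwise maximum over input channels $2i$ and $2i + 1$.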
Please refer to the papers:
  - Maxout Networks: http://www.jmlr.org/proceedings/papers/v28/goodfellow13.pdf
  - Multi-digit Number Recognition from Street View \
    Imagery using Deep Convolutional Neural Networks: \
    https://arxiv.org/pdf/1312.6082v4.pdf
)DOC");
}
};
class MaxOutOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
                      "Input(X) of MaxoutOp should not be null.");
    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
                      "Output(Out) of MaxoutOp should not be null.");
    auto in_x_dims = ctx->GetInputDim("X");
    int groups = ctx->Attrs().Get<int>("groups");
    int axis = ctx->Attrs().Get<int>("axis");
    // Normalize a negative axis (e.g. -1 for NHWC) to a positive index
    // before it is used to subscript the dims below.
    if (axis < 0) {
      axis += in_x_dims.size();
    }
    // check groups > 1
    PADDLE_ENFORCE_GT(groups, 1,
                      "Attr(groups) of Op(maxout) should be larger than 1.");
    PADDLE_ENFORCE_EQ(
        in_x_dims[axis] % groups, 0,
        "ValueError: The number of input channels for Op(maxout) "
        "should be divisible by Attr(groups). But received: the "
        "input's channels is [%d], the shape of input is [%s], "
        "the Attr(groups) is [%d], the Attr(axis) is [%d]. The "
        "error may come from wrong Attr(groups) or Attr(axis) setting.",
        in_x_dims[axis], in_x_dims, groups, axis);
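    // The output keeps the input's 4-D shape; only the channel dimension
    // shrinks by a factor of Attr(groups).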
    std::vector<int64_t> output_shape(
        {in_x_dims[0], in_x_dims[1], in_x_dims[2], in_x_dims[3]});
    output_shape[axis] = in_x_dims[axis] / groups;
    ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
  }
};
class MaxOutOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
                      "Input(X) of MaxOutOpGrad should not be null.");
    PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
                      "Output(Grad@X) of MaxOutOpGrad should not be null.");
    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
  }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
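// Register maxout with the default gradient-op makers for both the static
// graph (OpDesc) and imperative/dygraph (OpBase) paths, then register the
// gradient op and the float/double CPU kernels for both ops.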
REGISTER_OPERATOR(
    maxout, ops::MaxOutOp, ops::MaxOutOpMaker,
    paddle::framework::DefaultGradOpMaker<paddle::framework::OpDesc, true>,
    paddle::framework::DefaultGradOpMaker<paddle::imperative::OpBase, true>);
REGISTER_OPERATOR(maxout_grad, ops::MaxOutOpGrad);
REGISTER_OP_CPU_KERNEL(
    maxout, ops::MaxOutKernel<paddle::platform::CPUDeviceContext, float>,
    ops::MaxOutKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL(
    maxout_grad,
    ops::MaxOutGradKernel<paddle::platform::CPUDeviceContext, float>,
    ops::MaxOutGradKernel<paddle::platform::CPUDeviceContext, double>);