From 4a428c8fbbc5398912727107124484e563707c9c Mon Sep 17 00:00:00 2001 From: wanghaox Date: Sat, 11 Nov 2017 18:13:14 +0800 Subject: [PATCH 01/15] this for maxout op new add --- paddle/operators/math/maxouting.cc | 117 +++++++++++++ paddle/operators/math/maxouting.cu | 161 ++++++++++++++++++ paddle/operators/math/maxouting.h | 99 +++++++++++ paddle/operators/maxout_op.cc | 115 +++++++++++++ paddle/operators/maxout_op.cu | 23 +++ paddle/operators/maxout_op.h | 77 +++++++++ .../v2/framework/tests/test_maxout_op.py | 52 ++++++ 7 files changed, 644 insertions(+) create mode 100644 paddle/operators/math/maxouting.cc create mode 100644 paddle/operators/math/maxouting.cu create mode 100644 paddle/operators/math/maxouting.h create mode 100644 paddle/operators/maxout_op.cc create mode 100644 paddle/operators/maxout_op.cu create mode 100644 paddle/operators/maxout_op.h create mode 100644 python/paddle/v2/framework/tests/test_maxout_op.py diff --git a/paddle/operators/math/maxouting.cc b/paddle/operators/math/maxouting.cc new file mode 100644 index 0000000000..f01fa18391 --- /dev/null +++ b/paddle/operators/math/maxouting.cc @@ -0,0 +1,117 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/operators/math/maxouting.h" + +namespace paddle { +namespace operators { +namespace math { + +/* + * All tensors are in NCHW format. + * Ksize, strides, paddings are two elements. These two elements represent + * height and width, respectively. 
+ */ +template +class MaxOutFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& output, + int groups, int num_channels, MaxOutProcess maxout_process) { + const int batch_size = input.dims()[0]; + const int input_height = input.dims()[2]; + const int input_width = input.dims()[3]; + const int output_channels = num_channels/groups; + + int fea_size = input_height * input_width; + int c_size = fea_size * output_channels; + + const T* input_data = input.data(); + T* output_data = output.mutable_data(context.GetPlace()); + + for (int i = 0; i < batch_size; i++) { + int new_bindex = c_size * i; + for (int c = 0; c < output_channels; ++c) { + int new_cindex = fea_size * c; + for (int f = 0; f < fea_size; f++) { + T ele = maxout_process.initial(); + for (int ph = 0; ph < groups; ++ph) { + maxout_process.compute(ele, + input_data[(new_bindex+new_cindex) * groups+ph*fea_size+f]); + } + maxout_process.finalize(ele, (static_cast(groups))); + output_data[(new_bindex+new_cindex+f)] = ele; + } + } + } + } +}; + + + +template +class MaxOutGradFunctor { +public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, + framework::Tensor& input_grad, + const framework::Tensor& output, + const framework::Tensor& output_grad, + int groups, int num_channels) { + const int batch_size = input.dims()[0]; + const int input_height = input.dims()[2]; + const int input_width = input.dims()[3]; + const int output_channels = num_channels / groups; + + int fea_size = input_height * input_width; + + const T* input_data = input.data(); + const T* output_data = output.data(); + const T* output_grad_data = output_grad.data(); + T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + + for (int i = 0; i < batch_size; i++) { + int blen = fea_size * output_channels * i; + for (int c = 0; c < output_channels; ++c) { + int clen = fea_size * c; + for (int f = 0; f < fea_size; f++) { + int input_idx = 0; + bool stop = false; + int output_idx = blen + clen + f; + for (int g = 0; g < groups && !stop; g++) { + input_idx = (blen + clen) * groups + fea_size * g + f; + input_grad_data[input_idx] = 0; + if (input_data[input_idx] == output_data[output_idx]) { + input_grad_data[input_idx] += output_grad_data[output_idx]; + stop = true; + } else { + input_grad_data[input_idx] = 0; + } + } + } + } + } + } +}; + +template class MaxOutGradFunctor; +template class MaxOutGradFunctor; +template class MaxOutFunctor, float>; +template class MaxOutFunctor, double>; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/maxouting.cu b/paddle/operators/math/maxouting.cu new file mode 100644 index 0000000000..b1c0dd8fd4 --- /dev/null +++ b/paddle/operators/math/maxouting.cu @@ -0,0 +1,161 @@ +/* Copyright (c) 2016 paddlepaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#include "paddle/operators/math/maxouting.h" +#include "paddle/platform/cuda_helper.h" + +namespace paddle { +namespace operators { +namespace math { + +template +__global__ void KernelMaxOut(const int nthreads, const T* input_data, + T* output_data, const int channels, + const int input_height, const int input_width, + int groups, MaxOutProcess maxout_process) { + int size = input_height * input_width * channels / groups; + int featLen = input_height * input_width; + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; + index += blockDim.x * gridDim.x) { + int batch_idx = index / size; + int i = index % size; + int channel_idx = i / featLen; + int feat_idx = i % featLen; + int data_idx = + (batch_idx * size + channel_idx * featLen) * groups + feat_idx; + T ele = maxout_process.initial(); + for (int g = 0; g < groups; g++) { + maxout_process.compute(ele, input_data[data_idx + g * featLen]); + } + maxout_process.finalize(ele, (static_cast(groups))); + output_data[index] = ele; + } +} +template +__global__ void KernelMaxoutGrad( + const int nthreads, const T* input_data, const T* output_data, + const T* output_grad, T* input_grad, const int channels, + const int input_height, const int input_width, int groups) { + int size = input_height * input_width * channels / groups; + int featLen = input_height * input_width; + for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; + index += blockDim.x * gridDim.x) { + int batch_idx = index / size; + int i = index % size; + int channel_idx = i / featLen; + int feat_idx = i % featLen; + int data_idx = + (batch_idx * size + channel_idx * featLen) * groups + feat_idx; + int maxIndex = -1; + bool stop = false; + for (int g = 0; g < groups && !stop; g++) { + if (input_data[data_idx + g * featLen] == output_data[index]) { + maxIndex = data_idx + g * featLen; + stop = true; + } + } + if (maxIndex != -1) { + // atomic add + platform::CudaAtomicAdd(input_grad + maxIndex, output_grad[index]); + } + } +} +/* + * All tensors are in NCHW format. + * Ksize, strides, paddings are two elements. These two elements represent + * height and width, respectively. + */ +template +class MaxOutFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& output, + int groups, int num_channels, + MaxOutProcess maxout_process) { + const int batch_size = input.dims()[0]; + const int input_channels = input.dims()[1]; + const int input_height = input.dims()[2]; + const int input_width = input.dims()[3]; + const int output_channels = num_channels / groups; + const int output_height = output.dims()[2]; + const int output_width = output.dims()[3]; + + const T* input_data = input.data(); + T* output_data = output.mutable_data(context.GetPlace()); + + int nthreads = batch_size * output_channels * output_height * output_width; + int blocks = (nthreads + 1024 - 1) / 1024; + dim3 threads(1024, 1); + dim3 grid(blocks, 1); + + KernelMaxOut< + MaxOutProcess, + T><<(context) + .stream()>>>(nthreads, input_data, output_data, input_channels, + input_height, input_width, groups, + maxout_process); + } +}; +/* + * All tensors are in NCHW format. + * Ksize, strides, paddings are two elements. These two elements represent + * height and width, respectively. 
+ */ +template +class MaxOutGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& input_grad, + const framework::Tensor& output, + const framework::Tensor& output_grad, + int groups, int num_channels) { + const int batch_size = input.dims()[0]; + const int input_channels = input.dims()[1]; + const int input_height = input.dims()[2]; + const int input_width = input.dims()[3]; + const int output_channels = output.dims()[1]; + const int output_height = output.dims()[2]; + const int output_width = output.dims()[3]; + + const T* input_data = input.data(); + const T* output_data = output.data(); + const T* output_grad_data = output_grad.data(); + T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + + int nthreads = batch_size * output_channels * output_height * output_width; + int blocks = (nthreads + 1024 - 1) / 1024; + dim3 threads(1024, 1); + dim3 grid(blocks, 1); + + KernelMaxoutGrad< + T><<(context) + .stream()>>>( + nthreads, input_data, output_data, output_grad_data, input_grad_data, + input_channels, input_height, input_width, groups); + } +}; + +template class MaxOutGradFunctor; +template class MaxOutGradFunctor; + +template class MaxOutFunctor, float>; +template class MaxOutFunctor, double>; + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/math/maxouting.h b/paddle/operators/math/maxouting.h new file mode 100644 index 0000000000..aeac084944 --- /dev/null +++ b/paddle/operators/math/maxouting.h @@ -0,0 +1,99 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/tensor.h" +#include "paddle/platform/device_context.h" +#include "paddle/platform/hostdevice.h" + +namespace paddle { +namespace operators { +namespace math { + +#define FLT_MAX \ + __FLT_MAX__ // It might need to be placed in another file, but I'm still + // wondering where to put it. + +/* + * \brief Extracting simple operations from pooling. + * Both MaxPool and AvgPool need "initial", "compute" and "finalize" + * operation. + * MaxPool initializes temp variable to the negative maximum to find the + * maximum value in the pooling field. + * AvgPool initializes temp variable to the zero to accumulate all values + * in pool pooling, and finally takes the average. + * MaxPoolGrad and AvgPoolGrad are gradient operations respectively. + */ +template +class MaxOut { + public: + DEVICE inline T initial() { return static_cast(-FLT_MAX); } + DEVICE inline void compute(T& y, const T& x) { y = y > x ? y : x; } + DEVICE inline void finalize(T& y, const T& group) {} +}; + +template +class MaxOutGrad { + public: + DEVICE inline void compute(const T& x, const T& y, const T& dy, T& dx, + T scale) { + dx += dy * (x == y); + } +}; + + +/* + * \brief Getting pooling results, and calculating gradient. + * + * In pool2d, all tensors are in NCHW format. 
Where N is batch size, C is the + * number of channels, H and W is the height and width of feature. + * In pool3d, all tensors are in NCDHW format. Where N is batch size, C is the + * number of channels, D, H and W is the depth, height and width of feature. + * + * In max pooling, it is possible that the pooling region has multiple maximum + * elements. In this case, we should compute the gradient of the first maximum + * element. + * This is different from average pooling. So we rewrite the max_pool_grad: + * MaxPool2dGradFunctor, MaxPool3dGradFunctor. + */ +template +class MaxOutFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, framework::Tensor& output, + int groups, int num_channels, MaxOutProcess maxout_compute); +}; + + +template +class MaxOutGradFunctor { + public: + void operator()(const platform::DeviceContext& context, + const framework::Tensor& input, + framework::Tensor& input_grad, + const framework::Tensor& output, + const framework::Tensor& output_grad, int groups, + int num_channels); +}; + + + + + + + +} // namespace math +} // namespace operators +} // namespace paddle diff --git a/paddle/operators/maxout_op.cc b/paddle/operators/maxout_op.cc new file mode 100644 index 0000000000..41b3860a86 --- /dev/null +++ b/paddle/operators/maxout_op.cc @@ -0,0 +1,115 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. */ + + +#include "paddle/operators/maxout_op.h" +namespace paddle { +namespace operators { + +using framework::Tensor; + +/********first define ProtoMakerē±» ***************/ +class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker { + public: + MaxOutOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", + "(Tensor) The input tensor of pooling operator. " + "The format of input tensor is NCHW. Where N is batch size, C is the " + "number of channels, H and W is the height and width of feature."); + AddOutput("Out", + "(Tensor) The output tensor of pooling operator." + "The format of output tensor is also NCHW." + "Where N is batch size, C is " + "the number of channels, H and W is the height and " + "width of feature."); + + AddAttr( + "groups", + R"DOC(The group number of input layer. + )DOC") + .SetDefault(2); + AddAttr( + "num_channels", + R"DOC(The channel number of input layer. + )DOC") + .SetDefault(0); + AddComment(R"DOC(A layer to do max out on conv layer output. + - Input: output of a conv layer. + - Output: feature map size same as input. Channel is (input channel) / groups. + So groups should be larger than 1, and the num of channels should be able + to devided by groups. 
+ )DOC"); + } +}; + +/******************2nd **********************************/ + +class MaxOutOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of maxoutOp" + "should not be null."); + PADDLE_ENFORCE(ctx->HasOutput("Out"), + "Output(Out) of maxoutOp should not be null."); + auto in_x_dims = ctx->GetInputDim("X"); + int groups = ctx->Attrs().Get("groups"); + int num_channels = ctx->Attrs().Get("num_channels"); + + // check groups > 1 + PADDLE_ENFORCE_GT( + groups, 1, + "in maxoutop groups should be larger than 1"); + // check num_channels%groups=0 + PADDLE_ENFORCE_EQ(num_channels % groups, 0, + "the num of channels should be able" + "to devided by groups"); + + int out_num_channels = num_channels / groups; + + std::vector output_shape({in_x_dims[0], out_num_channels}); + output_shape.push_back(in_x_dims[2]); + output_shape.push_back(in_x_dims[3]); + + ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); + } +}; + + +class MaxOutOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); + PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), + "Input(X@GRAD) should not be null."); + ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); + } +}; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(maxout, ops::MaxOutOp, ops::MaxOutOpMaker, maxout_grad, + ops::MaxOutOpGrad); + + +REGISTER_OP_CPU_KERNEL(maxout, ops::MaxOutKernel); +REGISTER_OP_CPU_KERNEL(maxout_grad, + ops::MaxOutGradKernel); diff --git a/paddle/operators/maxout_op.cu b/paddle/operators/maxout_op.cu new file mode 100644 index 0000000000..44a149b065 --- /dev/null +++ b/paddle/operators/maxout_op.cu @@ -0,0 +1,23 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/maxout_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(maxout, ops::MaxOutKernel); +REGISTER_OP_GPU_KERNEL(maxout_grad, + ops::MaxOutGradKernel); diff --git a/paddle/operators/maxout_op.h b/paddle/operators/maxout_op.h new file mode 100644 index 0000000000..2321613512 --- /dev/null +++ b/paddle/operators/maxout_op.h @@ -0,0 +1,77 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" +#include "paddle/operators/math/math_function.h" +#include "paddle/operators/math/maxouting.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class MaxOutKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* in_x = context.Input("X"); + Tensor* out = context.Output("Out"); + + int groups = context.template Attr("groups"); + int num_channels = context.template Attr("num_channels"); + + + paddle::operators::math::MaxOutFunctor< + Place, paddle::operators::math::MaxOut, T> + maxout_forward; + paddle::operators::math::MaxOut maxout_process; + maxout_forward(context.device_context(), *in_x, *out, groups, num_channels, + maxout_process); + } +}; + +template +class MaxOutGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + const Tensor* in_x = context.Input("X"); + const Tensor* out = context.Input("Out"); + const Tensor* out_grad = + context.Input(framework::GradVarName("Out")); + Tensor* in_x_grad = context.Output(framework::GradVarName("X")); + + int groups = context.template Attr("groups"); + int num_channels = context.template Attr("num_channels"); + + + + if (in_x_grad) { + in_x_grad->mutable_data(context.GetPlace()); + auto temp = framework::EigenVector::Flatten(*in_x_grad); + temp.device(context.GetEigenDevice()) = + temp.constant(static_cast(0)); + + paddle::operators::math::MaxOutGradFunctor + maxout_backward; + maxout_backward(context.device_context(), *in_x, *in_x_grad, *out, + *out_grad, groups, num_channels); + } + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/framework/tests/test_maxout_op.py b/python/paddle/v2/framework/tests/test_maxout_op.py new file mode 100644 index 0000000000..4ea1e3c29c --- /dev/null +++ b/python/paddle/v2/framework/tests/test_maxout_op.py @@ -0,0 +1,52 @@ +import unittest +import numpy as np +from op_test import OpTest + + + +def maxout_forward_naive_2sweetsky(input, groups, num_channels): + s0, s1, s2, s3 = input.shape + return np.ndarray([s0, s1 / groups, groups, s2, s3], \ + buffer = input, dtype=input.dtype).max(axis=(2)) + + +def maxout_forward_naive(input, groups,num_channels): + s0, s1, s2, s3 = input.shape + return np.ndarray([s0, s1 / groups, groups, s2, s3], \ + buffer = input, dtype=input.dtype).max(axis=(2)) + + + + +class TestMaxOut_Op(OpTest): + def setUp(self): + self.op_type = "maxout" + self.init_test_case() + input = np.random.random(self.shape).astype("float32") + output = self.MaxOut_forward_naive(input, self.groups, + self.num_channels).astype("float32") + + self.inputs = {'X': input} + self.attrs = {'groups': self.groups, 'num_channels': self.num_channels} + + self.outputs = {'Out': output.astype('float32')} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + print self.inputs + print self.outputs + self.check_grad(['X'], 'Out', max_relative_error=0.5) + + def 
init_test_case(self): + self.MaxOut_forward_naive = maxout_forward_naive + self.shape = [100, 6, 2, 2] + self.groups=2 + self.num_channels=6 + + + + +if __name__ == '__main__': + unittest.main() From 058bdd345d317db00b661c0c4fdf4acaca6710f8 Mon Sep 17 00:00:00 2001 From: wanghaox Date: Sat, 11 Nov 2017 18:17:01 +0800 Subject: [PATCH 02/15] this for maxout op new add --- paddle/operators/CMakeLists.txt | 4 +++- paddle/operators/math/CMakeLists.txt | 2 ++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/paddle/operators/CMakeLists.txt b/paddle/operators/CMakeLists.txt index 13ebb0ad65..d39f7bf452 100644 --- a/paddle/operators/CMakeLists.txt +++ b/paddle/operators/CMakeLists.txt @@ -96,7 +96,7 @@ function(op_library TARGET) # It's enough to just adding one operator to pybind file(APPEND ${pybind_file} "USE_GPU_ONLY_OP(ncclAllReduce);\n") endif() - + # reduce_op contains several operators if ("${TARGET}" STREQUAL "reduce_op") set(pybind_flag 1) @@ -138,6 +138,7 @@ set(DEPS_OPS softmax_with_cross_entropy_op sum_op pool_op + maxout_op pool_with_index_op nccl_op sequence_conv_op @@ -149,6 +150,7 @@ op_library(cross_entropy_op DEPS cross_entropy) op_library(softmax_with_cross_entropy_op DEPS cross_entropy softmax) op_library(sum_op DEPS net_op selected_rows_functor) op_library(pool_op DEPS pooling) +op_library(maxout_op DEPS maxouting) op_library(pool_with_index_op DEPS pooling) op_library(lod_rank_table_op SRCS lod_rank_table_op.cc DEPS lod_rank_table) if(WITH_GPU) diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index 40cc177d0f..b39a64c0f3 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -8,6 +8,7 @@ if(WITH_GPU) nv_library(softmax SRCS softmax.cc softmax.cu DEPS operator) nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator) nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context) + nv_library(maxouting SRCS maxouting.cc maxouting.cu DEPS device_context) nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context) nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context) nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context) @@ -18,6 +19,7 @@ else() cc_library(softmax SRCS softmax.cc DEPS operator) cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator) cc_library(pooling SRCS pooling.cc DEPS device_context) + cc_library(maxouting SRCS maxouting.cc DEPS device_context) cc_library(vol2col SRCS vol2col.cc DEPS device_context) cc_library(context_project SRCS context_project.cc DEPS device_context) cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context) From bd773b9c8429a64287d840eb5bd297c882b1d9d7 Mon Sep 17 00:00:00 2001 From: wanghaox Date: Tue, 14 Nov 2017 14:20:50 +0800 Subject: [PATCH 03/15] modify for maxoutop code review --- paddle/operators/math/CMakeLists.txt | 6 +- paddle/operators/math/maxouting.cc | 25 ++++---- paddle/operators/math/maxouting.cu | 61 ++++++++---------- paddle/operators/math/maxouting.h | 22 +++---- paddle/operators/maxout_op.cc | 63 +++++++++++++------ paddle/operators/maxout_op.h | 7 +-- .../v2/framework/tests/test_maxout_op.py | 13 +--- 7 files changed, 98 insertions(+), 99 deletions(-) diff --git a/paddle/operators/math/CMakeLists.txt b/paddle/operators/math/CMakeLists.txt index fb83b14782..3b4af8e439 100644 --- a/paddle/operators/math/CMakeLists.txt +++ b/paddle/operators/math/CMakeLists.txt @@ -8,24 +8,26 @@ if(WITH_GPU) nv_library(softmax 
SRCS softmax.cc softmax.cu DEPS operator) nv_library(cross_entropy SRCS cross_entropy.cc cross_entropy.cu DEPS operator) nv_library(pooling SRCS pooling.cc pooling.cu DEPS device_context) - nv_library(maxouting SRCS maxouting.cc maxouting.cu DEPS device_context) + nv_library(sequence_pooling SRCS sequence_pooling.cc sequence_pooling.cu DEPS device_context math_function) nv_library(vol2col SRCS vol2col.cc vol2col.cu DEPS device_context) nv_library(context_project SRCS context_project.cc context_project.cu DEPS device_context) nv_library(sequence2batch SRCS sequence2batch.cc sequence2batch.cu DEPS device_context) nv_library(lstm_compute SRCS lstm_compute.cc lstm_compute.cu DEPS device_context activation_functions) nv_library(gru_compute SRCS gru_compute.cc gru_compute.cu DEPS device_context activation_functions math_function) + nv_library(maxouting SRCS maxouting.cc maxouting.cu DEPS device_context) else() cc_library(math_function SRCS math_function.cc im2col.cc DEPS cblas device_context operator) cc_library(selected_rows_functor SRCS selected_rows_functor.cc DEPS selected_rows math_function) cc_library(softmax SRCS softmax.cc DEPS operator) cc_library(cross_entropy SRCS cross_entropy.cc DEPS operator) cc_library(pooling SRCS pooling.cc DEPS device_context) - cc_library(maxouting SRCS maxouting.cc DEPS device_context) + cc_library(sequence_pooling SRCS sequence_pooling.cc DEPS device_context math_function) cc_library(vol2col SRCS vol2col.cc DEPS device_context) cc_library(context_project SRCS context_project.cc DEPS device_context) cc_library(sequence2batch SRCS sequence2batch.cc DEPS device_context) cc_library(lstm_compute SRCS lstm_compute.cc DEPS device_context activation_functions) cc_library(gru_compute SRCS gru_compute.cc DEPS device_context activation_functions math_function) + cc_library(maxouting SRCS maxouting.cc DEPS device_context) endif() cc_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor) diff --git a/paddle/operators/math/maxouting.cc b/paddle/operators/math/maxouting.cc index f01fa18391..a634e49f48 100644 --- a/paddle/operators/math/maxouting.cc +++ b/paddle/operators/math/maxouting.cc @@ -20,25 +20,27 @@ namespace math { /* * All tensors are in NCHW format. - * Ksize, strides, paddings are two elements. These two elements represent - * height and width, respectively. 
+ * groups mustbe > 1 */ template class MaxOutFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& output, - int groups, int num_channels, MaxOutProcess maxout_process) { + const framework::Tensor& input, + framework::Tensor * output, + int groups, + MaxOutProcess maxout_process) { const int batch_size = input.dims()[0]; const int input_height = input.dims()[2]; const int input_width = input.dims()[3]; - const int output_channels = num_channels/groups; + const int output_channels = output->dims()[1]; int fea_size = input_height * input_width; + // c_size mean output one batch size int c_size = fea_size * output_channels; const T* input_data = input.data(); - T* output_data = output.mutable_data(context.GetPlace()); + T* output_data = output->mutable_data(context.GetPlace()); for (int i = 0; i < batch_size; i++) { int new_bindex = c_size * i; @@ -50,7 +52,6 @@ class MaxOutFunctor { maxout_process.compute(ele, input_data[(new_bindex+new_cindex) * groups+ph*fea_size+f]); } - maxout_process.finalize(ele, (static_cast(groups))); output_data[(new_bindex+new_cindex+f)] = ele; } } @@ -68,11 +69,11 @@ public: framework::Tensor& input_grad, const framework::Tensor& output, const framework::Tensor& output_grad, - int groups, int num_channels) { + int groups) { const int batch_size = input.dims()[0]; const int input_height = input.dims()[2]; const int input_width = input.dims()[3]; - const int output_channels = num_channels / groups; + const int output_channels = output.dims()[1]; int fea_size = input_height * input_width; @@ -95,8 +96,6 @@ public: if (input_data[input_idx] == output_data[output_idx]) { input_grad_data[input_idx] += output_grad_data[output_idx]; stop = true; - } else { - input_grad_data[input_idx] = 0; } } } @@ -108,9 +107,9 @@ public: template class MaxOutGradFunctor; template class MaxOutGradFunctor; template class MaxOutFunctor, float>; + math::MaxOut, float>; template class MaxOutFunctor, double>; + math::MaxOut, double>; } // namespace math } // namespace operators diff --git a/paddle/operators/math/maxouting.cu b/paddle/operators/math/maxouting.cu index b1c0dd8fd4..42acaa2c73 100644 --- a/paddle/operators/math/maxouting.cu +++ b/paddle/operators/math/maxouting.cu @@ -24,21 +24,20 @@ __global__ void KernelMaxOut(const int nthreads, const T* input_data, T* output_data, const int channels, const int input_height, const int input_width, int groups, MaxOutProcess maxout_process) { - int size = input_height * input_width * channels / groups; - int featLen = input_height * input_width; + const int size = input_height * input_width * channels / groups; + const int feat_len = input_height * input_width; for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int batch_idx = index / size; - int i = index % size; - int channel_idx = i / featLen; - int feat_idx = i % featLen; + int batch_offset = index % size; + int channel_idx = batch_offset / feat_len; + int feat_idx = batch_offset % feat_len; int data_idx = - (batch_idx * size + channel_idx * featLen) * groups + feat_idx; + (batch_idx * size + channel_idx * feat_len) * groups + feat_idx; T ele = maxout_process.initial(); - for (int g = 0; g < groups; g++) { - maxout_process.compute(ele, input_data[data_idx + g * featLen]); + for (int g = 0; g < groups; ++g) { + maxout_process.compute(ele, input_data[data_idx + g * feat_len]); } - maxout_process.finalize(ele, (static_cast(groups))); output_data[index] = ele; } } 
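For reference, the forward computation that KernelMaxOut above implements is equivalent to reshaping the NCHW input to (N, C/groups, groups, H, W) and taking the maximum over the groups axis. A minimal NumPy sketch of that forward pass (it mirrors maxout_forward_naive in test_maxout_op.py; the concrete shape and groups value below are illustrative only, taken from the test's init_test_case):

    import numpy as np

    def maxout_forward_naive(x, groups):
        # x: (N, C, H, W) -> output: (N, C // groups, H, W),
        # each output element is the max over `groups` consecutive input channels.
        n, c, h, w = x.shape
        return x.reshape(n, c // groups, groups, h, w).max(axis=2)

    x = np.random.random((100, 6, 2, 2)).astype("float32")  # shape used by TestMaxOutOp
    out = maxout_forward_naive(x, groups=2)
    assert out.shape == (100, 3, 2, 2)
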
@@ -47,21 +46,21 @@ __global__ void KernelMaxoutGrad( const int nthreads, const T* input_data, const T* output_data, const T* output_grad, T* input_grad, const int channels, const int input_height, const int input_width, int groups) { - int size = input_height * input_width * channels / groups; - int featLen = input_height * input_width; + const int size = input_height * input_width * channels / groups; + const int feat_len = input_height * input_width; for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; index += blockDim.x * gridDim.x) { int batch_idx = index / size; - int i = index % size; - int channel_idx = i / featLen; - int feat_idx = i % featLen; + int batch_offset = index % size; + int channel_idx = batch_offset / feat_len; + int feat_idx = batch_offset % feat_len; int data_idx = - (batch_idx * size + channel_idx * featLen) * groups + feat_idx; + (batch_idx * size + channel_idx * feat_len) * groups + feat_idx; int maxIndex = -1; bool stop = false; for (int g = 0; g < groups && !stop; g++) { - if (input_data[data_idx + g * featLen] == output_data[index]) { - maxIndex = data_idx + g * featLen; + if (input_data[data_idx + g * feat_len] == output_data[index]) { + maxIndex = data_idx + g * feat_len; stop = true; } } @@ -73,28 +72,25 @@ __global__ void KernelMaxoutGrad( } /* * All tensors are in NCHW format. - * Ksize, strides, paddings are two elements. These two elements represent - * height and width, respectively. */ template class MaxOutFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& output, - int groups, int num_channels, + const framework::Tensor& input, framework::Tensor * output, + int groups, MaxOutProcess maxout_process) { const int batch_size = input.dims()[0]; const int input_channels = input.dims()[1]; const int input_height = input.dims()[2]; const int input_width = input.dims()[3]; - const int output_channels = num_channels / groups; - const int output_height = output.dims()[2]; - const int output_width = output.dims()[3]; + const int output_channels = output->dims()[1]; + const int output_height = output->dims()[2]; + const int output_width = output->dims()[3]; const T* input_data = input.data(); - T* output_data = output.mutable_data(context.GetPlace()); - - int nthreads = batch_size * output_channels * output_height * output_width; + T* output_data = output->mutable_data(context.GetPlace()); + int nthreads = output->numel(); int blocks = (nthreads + 1024 - 1) / 1024; dim3 threads(1024, 1); dim3 grid(blocks, 1); @@ -110,8 +106,6 @@ class MaxOutFunctor { }; /* * All tensors are in NCHW format. - * Ksize, strides, paddings are two elements. These two elements represent - * height and width, respectively. 
*/ template class MaxOutGradFunctor { @@ -120,7 +114,7 @@ class MaxOutGradFunctor { const framework::Tensor& input, framework::Tensor& input_grad, const framework::Tensor& output, const framework::Tensor& output_grad, - int groups, int num_channels) { + int groups) { const int batch_size = input.dims()[0]; const int input_channels = input.dims()[1]; const int input_height = input.dims()[2]; @@ -133,8 +127,7 @@ class MaxOutGradFunctor { const T* output_data = output.data(); const T* output_grad_data = output_grad.data(); T* input_grad_data = input_grad.mutable_data(context.GetPlace()); - - int nthreads = batch_size * output_channels * output_height * output_width; + int nthreads = output.numel(); int blocks = (nthreads + 1024 - 1) / 1024; dim3 threads(1024, 1); dim3 grid(blocks, 1); @@ -152,9 +145,9 @@ template class MaxOutGradFunctor; template class MaxOutGradFunctor; template class MaxOutFunctor, float>; + math::MaxOut, float>; template class MaxOutFunctor, double>; + math::MaxOut, double>; } // namespace math } // namespace operators diff --git a/paddle/operators/math/maxouting.h b/paddle/operators/math/maxouting.h index aeac084944..6aaa1656a7 100644 --- a/paddle/operators/math/maxouting.h +++ b/paddle/operators/math/maxouting.h @@ -22,26 +22,20 @@ namespace paddle { namespace operators { namespace math { + #define FLT_MAX \ - __FLT_MAX__ // It might need to be placed in another file, but I'm still - // wondering where to put it. + __FLT_MAX__ /* - * \brief Extracting simple operations from pooling. - * Both MaxPool and AvgPool need "initial", "compute" and "finalize" + * \brief Extracting simple operations from maxout. + * need "initial", "compute" * operation. - * MaxPool initializes temp variable to the negative maximum to find the - * maximum value in the pooling field. - * AvgPool initializes temp variable to the zero to accumulate all values - * in pool pooling, and finally takes the average. - * MaxPoolGrad and AvgPoolGrad are gradient operations respectively. */ template class MaxOut { public: DEVICE inline T initial() { return static_cast(-FLT_MAX); } DEVICE inline void compute(T& y, const T& x) { y = y > x ? y : x; } - DEVICE inline void finalize(T& y, const T& group) {} }; template @@ -69,11 +63,12 @@ class MaxOutGrad { * MaxPool2dGradFunctor, MaxPool3dGradFunctor. */ template + class MaxOutFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& output, - int groups, int num_channels, MaxOutProcess maxout_compute); + const framework::Tensor& input, framework::Tensor * output, + int groups, MaxOutProcess maxout_compute); }; @@ -84,8 +79,7 @@ class MaxOutGradFunctor { const framework::Tensor& input, framework::Tensor& input_grad, const framework::Tensor& output, - const framework::Tensor& output_grad, int groups, - int num_channels); + const framework::Tensor& output_grad, int groups); }; diff --git a/paddle/operators/maxout_op.cc b/paddle/operators/maxout_op.cc index 41b3860a86..c54a706979 100644 --- a/paddle/operators/maxout_op.cc +++ b/paddle/operators/maxout_op.cc @@ -19,17 +19,16 @@ namespace operators { using framework::Tensor; -/********first define ProtoMakerē±» ***************/ class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker { public: MaxOutOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddInput("X", - "(Tensor) The input tensor of pooling operator. " + "(Tensor) The input tensor of maxout operator. 
" "The format of input tensor is NCHW. Where N is batch size, C is the " "number of channels, H and W is the height and width of feature."); AddOutput("Out", - "(Tensor) The output tensor of pooling operator." + "(Tensor) The output tensor of maxout operator." "The format of output tensor is also NCHW." "Where N is batch size, C is " "the number of channels, H and W is the height and " @@ -38,23 +37,53 @@ class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker { AddAttr( "groups", R"DOC(The group number of input layer. - )DOC") - .SetDefault(2); - AddAttr( - "num_channels", - R"DOC(The channel number of input layer. - )DOC") - .SetDefault(0); - AddComment(R"DOC(A layer to do max out on conv layer output. - - Input: output of a conv layer. + )DOC"); + AddComment(R"DOC( + - Input: NCHW. - Output: feature map size same as input. Channel is (input channel) / groups. So groups should be larger than 1, and the num of channels should be able to devided by groups. + + .. math:: + y_{si+j} = \max_k x_{gsi + sk + j} + g = groups + s = input.size / num_channels + 0 \le i < num_channels / groups + 0 \le j < s + 0 \le k < groups + + Please refer to Paper: + - Maxout Networks: http://www.jmlr.org/proceedings/papers/v28/goodfellow13.pdf + - Multi-digit Number Recognition from Street View \ + Imagery using Deep Convolutional Neural Networks: \ + https://arxiv.org/pdf/1312.6082v4.pdf + + The simple usage is: + + .. code-block:: python + + maxout = maxout_layer(input, + num_channels=128, + groups=4) + + :param input: The input of this layer. + :type input: LayerOutput + :param num_channels: The channel number of input layer. If None will be set + automatically from previous output. + :type num_channels: int | None + :param groups: The group number of input layer. + :type groups: int + :param name: The name of this layer. It is optional. + :type name: None | basestring. + :param layer_attr: Extra Layer attribute. + :type layer_attr: ExtraLayerAttribute + :return: LayerOutput object. + :rtype: LayerOutput + )DOC"); } }; -/******************2nd **********************************/ class MaxOutOp : public framework::OperatorWithKernel { public: @@ -67,20 +96,14 @@ class MaxOutOp : public framework::OperatorWithKernel { "Output(Out) of maxoutOp should not be null."); auto in_x_dims = ctx->GetInputDim("X"); int groups = ctx->Attrs().Get("groups"); - int num_channels = ctx->Attrs().Get("num_channels"); // check groups > 1 PADDLE_ENFORCE_GT( groups, 1, "in maxoutop groups should be larger than 1"); - // check num_channels%groups=0 - PADDLE_ENFORCE_EQ(num_channels % groups, 0, - "the num of channels should be able" - "to devided by groups"); - int out_num_channels = num_channels / groups; - std::vector output_shape({in_x_dims[0], out_num_channels}); + std::vector output_shape({in_x_dims[0], in_x_dims[1] / groups}); output_shape.push_back(in_x_dims[2]); output_shape.push_back(in_x_dims[3]); diff --git a/paddle/operators/maxout_op.h b/paddle/operators/maxout_op.h index 2321613512..3f5897abd2 100644 --- a/paddle/operators/maxout_op.h +++ b/paddle/operators/maxout_op.h @@ -14,7 +14,6 @@ limitations under the License. 
*/ #pragma once -#include "paddle/framework/eigen.h" #include "paddle/framework/op_registry.h" #include "paddle/operators/math/math_function.h" #include "paddle/operators/math/maxouting.h" @@ -32,14 +31,13 @@ class MaxOutKernel : public framework::OpKernel { Tensor* out = context.Output("Out"); int groups = context.template Attr("groups"); - int num_channels = context.template Attr("num_channels"); paddle::operators::math::MaxOutFunctor< Place, paddle::operators::math::MaxOut, T> maxout_forward; paddle::operators::math::MaxOut maxout_process; - maxout_forward(context.device_context(), *in_x, *out, groups, num_channels, + maxout_forward(context.device_context(), *in_x, out, groups, maxout_process); } }; @@ -55,7 +53,6 @@ class MaxOutGradKernel : public framework::OpKernel { Tensor* in_x_grad = context.Output(framework::GradVarName("X")); int groups = context.template Attr("groups"); - int num_channels = context.template Attr("num_channels"); @@ -68,7 +65,7 @@ class MaxOutGradKernel : public framework::OpKernel { paddle::operators::math::MaxOutGradFunctor maxout_backward; maxout_backward(context.device_context(), *in_x, *in_x_grad, *out, - *out_grad, groups, num_channels); + *out_grad, groups); } } }; diff --git a/python/paddle/v2/framework/tests/test_maxout_op.py b/python/paddle/v2/framework/tests/test_maxout_op.py index 4ea1e3c29c..406147ef24 100644 --- a/python/paddle/v2/framework/tests/test_maxout_op.py +++ b/python/paddle/v2/framework/tests/test_maxout_op.py @@ -3,22 +3,13 @@ import numpy as np from op_test import OpTest - -def maxout_forward_naive_2sweetsky(input, groups, num_channels): - s0, s1, s2, s3 = input.shape - return np.ndarray([s0, s1 / groups, groups, s2, s3], \ - buffer = input, dtype=input.dtype).max(axis=(2)) - - def maxout_forward_naive(input, groups,num_channels): s0, s1, s2, s3 = input.shape return np.ndarray([s0, s1 / groups, groups, s2, s3], \ buffer = input, dtype=input.dtype).max(axis=(2)) - - -class TestMaxOut_Op(OpTest): +class TestMaxOutOp(OpTest): def setUp(self): self.op_type = "maxout" self.init_test_case() @@ -37,7 +28,7 @@ class TestMaxOut_Op(OpTest): def test_check_grad(self): print self.inputs print self.outputs - self.check_grad(['X'], 'Out', max_relative_error=0.5) + self.check_grad(['X'], 'Out') def init_test_case(self): self.MaxOut_forward_naive = maxout_forward_naive From f57cd1e0f9a9a263e12df1cf0c5273e975299a33 Mon Sep 17 00:00:00 2001 From: wanghaox Date: Tue, 14 Nov 2017 18:06:48 +0800 Subject: [PATCH 04/15] del a err comments --- paddle/operators/math/maxouting.h | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/paddle/operators/math/maxouting.h b/paddle/operators/math/maxouting.h index 6aaa1656a7..a8e91a25b5 100644 --- a/paddle/operators/math/maxouting.h +++ b/paddle/operators/math/maxouting.h @@ -48,20 +48,6 @@ class MaxOutGrad { }; -/* - * \brief Getting pooling results, and calculating gradient. - * - * In pool2d, all tensors are in NCHW format. Where N is batch size, C is the - * number of channels, H and W is the height and width of feature. - * In pool3d, all tensors are in NCDHW format. Where N is batch size, C is the - * number of channels, D, H and W is the depth, height and width of feature. - * - * In max pooling, it is possible that the pooling region has multiple maximum - * elements. In this case, we should compute the gradient of the first maximum - * element. - * This is different from average pooling. So we rewrite the max_pool_grad: - * MaxPool2dGradFunctor, MaxPool3dGradFunctor. 
- */ template class MaxOutFunctor { From 8d9babf20407d1ea21ad66cf5c07ec61adb7398d Mon Sep 17 00:00:00 2001 From: wanghaox Date: Wed, 15 Nov 2017 15:47:00 +0800 Subject: [PATCH 05/15] maxout code review 2nd --- paddle/operators/math/maxouting.cc | 10 +++++----- paddle/operators/math/maxouting.cu | 11 ++++++----- paddle/operators/maxout_op.h | 8 +++----- python/paddle/v2/framework/tests/test_maxout_op.py | 2 -- 4 files changed, 14 insertions(+), 17 deletions(-) diff --git a/paddle/operators/math/maxouting.cc b/paddle/operators/math/maxouting.cc index a634e49f48..b733af7410 100644 --- a/paddle/operators/math/maxouting.cc +++ b/paddle/operators/math/maxouting.cc @@ -42,11 +42,11 @@ class MaxOutFunctor { const T* input_data = input.data(); T* output_data = output->mutable_data(context.GetPlace()); - for (int i = 0; i < batch_size; i++) { + for (int i = 0; i < batch_size; ++i) { int new_bindex = c_size * i; for (int c = 0; c < output_channels; ++c) { int new_cindex = fea_size * c; - for (int f = 0; f < fea_size; f++) { + for (int f = 0; f < fea_size; ++f) { T ele = maxout_process.initial(); for (int ph = 0; ph < groups; ++ph) { maxout_process.compute(ele, @@ -82,15 +82,15 @@ public: const T* output_grad_data = output_grad.data(); T* input_grad_data = input_grad.mutable_data(context.GetPlace()); - for (int i = 0; i < batch_size; i++) { + for (int i = 0; i < batch_size; ++i) { int blen = fea_size * output_channels * i; for (int c = 0; c < output_channels; ++c) { int clen = fea_size * c; - for (int f = 0; f < fea_size; f++) { + for (int f = 0; f < fea_size; ++f) { int input_idx = 0; bool stop = false; int output_idx = blen + clen + f; - for (int g = 0; g < groups && !stop; g++) { + for (int g = 0; g < groups && !stop; ++g) { input_idx = (blen + clen) * groups + fea_size * g + f; input_grad_data[input_idx] = 0; if (input_data[input_idx] == output_data[output_idx]) { diff --git a/paddle/operators/math/maxouting.cu b/paddle/operators/math/maxouting.cu index 42acaa2c73..c2da29e356 100644 --- a/paddle/operators/math/maxouting.cu +++ b/paddle/operators/math/maxouting.cu @@ -21,9 +21,10 @@ namespace math { template __global__ void KernelMaxOut(const int nthreads, const T* input_data, - T* output_data, const int channels, + const int channels, const int input_height, const int input_width, - int groups, MaxOutProcess maxout_process) { + int groups, T* output_data, + MaxOutProcess maxout_process) { const int size = input_height * input_width * channels / groups; const int feat_len = input_height * input_width; for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; @@ -58,7 +59,7 @@ __global__ void KernelMaxoutGrad( (batch_idx * size + channel_idx * feat_len) * groups + feat_idx; int maxIndex = -1; bool stop = false; - for (int g = 0; g < groups && !stop; g++) { + for (int g = 0; g < groups && !stop; ++g) { if (input_data[data_idx + g * feat_len] == output_data[index]) { maxIndex = data_idx + g * feat_len; stop = true; @@ -99,9 +100,9 @@ class MaxOutFunctor { MaxOutProcess, T><<(context) - .stream()>>>(nthreads, input_data, output_data, input_channels, + .stream()>>>(nthreads, input_data, input_channels, input_height, input_width, groups, - maxout_process); + output_data, maxout_process); } }; /* diff --git a/paddle/operators/maxout_op.h b/paddle/operators/maxout_op.h index 3f5897abd2..aab878af0f 100644 --- a/paddle/operators/maxout_op.h +++ b/paddle/operators/maxout_op.h @@ -54,13 +54,11 @@ class MaxOutGradKernel : public framework::OpKernel { int groups = context.template Attr("groups"); 
- - + auto& device_ctx = context.device_context(); + math::SetConstant zero; if (in_x_grad) { in_x_grad->mutable_data(context.GetPlace()); - auto temp = framework::EigenVector::Flatten(*in_x_grad); - temp.device(context.GetEigenDevice()) = - temp.constant(static_cast(0)); + zero(device_ctx, in_x_grad, static_cast(0.0)); paddle::operators::math::MaxOutGradFunctor maxout_backward; diff --git a/python/paddle/v2/framework/tests/test_maxout_op.py b/python/paddle/v2/framework/tests/test_maxout_op.py index 406147ef24..a7c47108f1 100644 --- a/python/paddle/v2/framework/tests/test_maxout_op.py +++ b/python/paddle/v2/framework/tests/test_maxout_op.py @@ -26,8 +26,6 @@ class TestMaxOutOp(OpTest): self.check_output() def test_check_grad(self): - print self.inputs - print self.outputs self.check_grad(['X'], 'Out') def init_test_case(self): From 5802880bbc7a4dec64a2dee1422c6fd6f3e4c3f9 Mon Sep 17 00:00:00 2001 From: wanghaox Date: Sun, 19 Nov 2017 16:51:39 +0800 Subject: [PATCH 06/15] update maxoutop for code review 3 --- paddle/operators/math/maxouting.cc | 36 ++++++++--------- paddle/operators/math/maxouting.cu | 62 +++++++++++++++--------------- paddle/operators/math/maxouting.h | 36 +---------------- paddle/operators/maxout_op.cc | 43 +++------------------ paddle/operators/maxout_op.h | 11 +----- 5 files changed, 54 insertions(+), 134 deletions(-) diff --git a/paddle/operators/math/maxouting.cc b/paddle/operators/math/maxouting.cc index b733af7410..baaa86ffce 100644 --- a/paddle/operators/math/maxouting.cc +++ b/paddle/operators/math/maxouting.cc @@ -22,23 +22,20 @@ namespace math { * All tensors are in NCHW format. * groups mustbe > 1 */ -template -class MaxOutFunctor { +template +class MaxOutFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& input, framework::Tensor * output, - int groups, - MaxOutProcess maxout_process) { + int groups) { const int batch_size = input.dims()[0]; const int input_height = input.dims()[2]; const int input_width = input.dims()[3]; const int output_channels = output->dims()[1]; - int fea_size = input_height * input_width; - // c_size mean output one batch size + // c_size means the output size of each sample int c_size = fea_size * output_channels; - const T* input_data = input.data(); T* output_data = output->mutable_data(context.GetPlace()); @@ -47,10 +44,11 @@ class MaxOutFunctor { for (int c = 0; c < output_channels; ++c) { int new_cindex = fea_size * c; for (int f = 0; f < fea_size; ++f) { - T ele = maxout_process.initial(); + // T ele = maxout_process.initial(); + T ele = static_cast(-FLT_MAX); for (int ph = 0; ph < groups; ++ph) { - maxout_process.compute(ele, - input_data[(new_bindex+new_cindex) * groups+ph*fea_size+f]); + T x=input_data[(new_bindex+new_cindex) * groups+ph*fea_size+f]; + ele = ele > x ? 
ele : x; } output_data[(new_bindex+new_cindex+f)] = ele; } @@ -74,9 +72,7 @@ public: const int input_height = input.dims()[2]; const int input_width = input.dims()[3]; const int output_channels = output.dims()[1]; - int fea_size = input_height * input_width; - const T* input_data = input.data(); const T* output_data = output.data(); const T* output_grad_data = output_grad.data(); @@ -87,15 +83,15 @@ public: for (int c = 0; c < output_channels; ++c) { int clen = fea_size * c; for (int f = 0; f < fea_size; ++f) { - int input_idx = 0; - bool stop = false; + int input_idx0 = (blen + clen) * groups + f; + bool continue_match = true; int output_idx = blen + clen + f; - for (int g = 0; g < groups && !stop; ++g) { - input_idx = (blen + clen) * groups + fea_size * g + f; + for (int g = 0; g < groups && continue_match; ++g) { + int input_idx = input_idx0 + fea_size * g; input_grad_data[input_idx] = 0; if (input_data[input_idx] == output_data[output_idx]) { input_grad_data[input_idx] += output_grad_data[output_idx]; - stop = true; + continue_match = false; } } } @@ -106,10 +102,8 @@ public: template class MaxOutGradFunctor; template class MaxOutGradFunctor; -template class MaxOutFunctor, float>; -template class MaxOutFunctor, double>; +template class MaxOutFunctor; +template class MaxOutFunctor; } // namespace math } // namespace operators diff --git a/paddle/operators/math/maxouting.cu b/paddle/operators/math/maxouting.cu index c2da29e356..1a8fc465cc 100644 --- a/paddle/operators/math/maxouting.cu +++ b/paddle/operators/math/maxouting.cu @@ -19,27 +19,28 @@ namespace paddle { namespace operators { namespace math { -template +template __global__ void KernelMaxOut(const int nthreads, const T* input_data, const int channels, const int input_height, const int input_width, - int groups, T* output_data, - MaxOutProcess maxout_process) { + int groups, T* output_data ) { const int size = input_height * input_width * channels / groups; const int feat_len = input_height * input_width; - for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; - index += blockDim.x * gridDim.x) { - int batch_idx = index / size; - int batch_offset = index % size; + int index = blockIdx.x * blockDim.x + threadIdx.x; + int offset = blockDim.x * gridDim.x; + for (int i = index; i < nthreads; i += offset) { + int batch_idx = i / size; + int batch_offset = i % size; int channel_idx = batch_offset / feat_len; int feat_idx = batch_offset % feat_len; int data_idx = (batch_idx * size + channel_idx * feat_len) * groups + feat_idx; - T ele = maxout_process.initial(); + T ele = static_cast(-FLT_MAX); for (int g = 0; g < groups; ++g) { - maxout_process.compute(ele, input_data[data_idx + g * feat_len]); + T x=input_data[data_idx + g * feat_len]; + ele = ele > x ? 
ele : x; } - output_data[index] = ele; + output_data[i] = ele; } } template @@ -49,38 +50,38 @@ __global__ void KernelMaxoutGrad( const int input_height, const int input_width, int groups) { const int size = input_height * input_width * channels / groups; const int feat_len = input_height * input_width; - for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads; - index += blockDim.x * gridDim.x) { - int batch_idx = index / size; - int batch_offset = index % size; + int index = blockIdx.x * blockDim.x + threadIdx.x; + int offset = blockDim.x * gridDim.x; + for (int i = index; i < nthreads; i += offset) { + int batch_idx = i / size; + int batch_offset = i % size; int channel_idx = batch_offset / feat_len; int feat_idx = batch_offset % feat_len; int data_idx = (batch_idx * size + channel_idx * feat_len) * groups + feat_idx; - int maxIndex = -1; - bool stop = false; - for (int g = 0; g < groups && !stop; ++g) { - if (input_data[data_idx + g * feat_len] == output_data[index]) { - maxIndex = data_idx + g * feat_len; - stop = true; + int max_index = -1; + bool continue_match = true; + for (int g = 0; g < groups && continue_match; ++g) { + if (input_data[data_idx + g * feat_len] == output_data[i]) { + max_index = data_idx + g * feat_len; + continue_match = false; } } - if (maxIndex != -1) { + if (max_index != -1) { // atomic add - platform::CudaAtomicAdd(input_grad + maxIndex, output_grad[index]); + platform::CudaAtomicAdd(input_grad + max_index, output_grad[index]); } } } /* * All tensors are in NCHW format. */ -template -class MaxOutFunctor { +template +class MaxOutFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& input, framework::Tensor * output, - int groups, - MaxOutProcess maxout_process) { + int groups) { const int batch_size = input.dims()[0]; const int input_channels = input.dims()[1]; const int input_height = input.dims()[2]; @@ -97,12 +98,11 @@ class MaxOutFunctor { dim3 grid(blocks, 1); KernelMaxOut< - MaxOutProcess, T><<(context) .stream()>>>(nthreads, input_data, input_channels, input_height, input_width, groups, - output_data, maxout_process); + output_data); } }; /* @@ -145,10 +145,8 @@ class MaxOutGradFunctor { template class MaxOutGradFunctor; template class MaxOutGradFunctor; -template class MaxOutFunctor, float>; -template class MaxOutFunctor, double>; +template class MaxOutFunctor; +template class MaxOutFunctor; } // namespace math } // namespace operators diff --git a/paddle/operators/math/maxouting.h b/paddle/operators/math/maxouting.h index a8e91a25b5..72f40d96f7 100644 --- a/paddle/operators/math/maxouting.h +++ b/paddle/operators/math/maxouting.h @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #pragma once -#include "paddle/framework/eigen.h" #include "paddle/framework/tensor.h" #include "paddle/platform/device_context.h" #include "paddle/platform/hostdevice.h" @@ -22,42 +21,18 @@ namespace paddle { namespace operators { namespace math { - #define FLT_MAX \ __FLT_MAX__ -/* - * \brief Extracting simple operations from maxout. - * need "initial", "compute" - * operation. - */ -template -class MaxOut { - public: - DEVICE inline T initial() { return static_cast(-FLT_MAX); } - DEVICE inline void compute(T& y, const T& x) { y = y > x ? 
y : x; } -}; - -template -class MaxOutGrad { - public: - DEVICE inline void compute(const T& x, const T& y, const T& dy, T& dx, - T scale) { - dx += dy * (x == y); - } -}; - - -template +template class MaxOutFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& input, framework::Tensor * output, - int groups, MaxOutProcess maxout_compute); + int groups ); }; - template class MaxOutGradFunctor { public: @@ -67,13 +42,6 @@ class MaxOutGradFunctor { const framework::Tensor& output, const framework::Tensor& output_grad, int groups); }; - - - - - - - } // namespace math } // namespace operators } // namespace paddle diff --git a/paddle/operators/maxout_op.cc b/paddle/operators/maxout_op.cc index c54a706979..f9277518cc 100644 --- a/paddle/operators/maxout_op.cc +++ b/paddle/operators/maxout_op.cc @@ -12,7 +12,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #include "paddle/operators/maxout_op.h" namespace paddle { namespace operators { @@ -33,18 +32,18 @@ class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker { "Where N is batch size, C is " "the number of channels, H and W is the height and " "width of feature."); - AddAttr( "groups", R"DOC(The group number of input layer. )DOC"); AddComment(R"DOC( - Input: NCHW. - - Output: feature map size same as input. Channel is (input channel) / groups. + - Output: The feature map size of output is the same as the input. + The output_channel is (input channel) / groups So groups should be larger than 1, and the num of channels should be able - to devided by groups. + to be devided by groups. - .. math:: + math: y_{si+j} = \max_k x_{gsi + sk + j} g = groups s = input.size / num_channels @@ -57,29 +56,6 @@ class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker { - Multi-digit Number Recognition from Street View \ Imagery using Deep Convolutional Neural Networks: \ https://arxiv.org/pdf/1312.6082v4.pdf - - The simple usage is: - - .. code-block:: python - - maxout = maxout_layer(input, - num_channels=128, - groups=4) - - :param input: The input of this layer. - :type input: LayerOutput - :param num_channels: The channel number of input layer. If None will be set - automatically from previous output. - :type num_channels: int | None - :param groups: The group number of input layer. - :type groups: int - :param name: The name of this layer. It is optional. - :type name: None | basestring. - :param layer_attr: Extra Layer attribute. - :type layer_attr: ExtraLayerAttribute - :return: LayerOutput object. 
- :rtype: LayerOutput - )DOC"); } }; @@ -88,7 +64,6 @@ class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker { class MaxOutOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of maxoutOp" "should not be null."); @@ -96,26 +71,20 @@ class MaxOutOp : public framework::OperatorWithKernel { "Output(Out) of maxoutOp should not be null."); auto in_x_dims = ctx->GetInputDim("X"); int groups = ctx->Attrs().Get("groups"); - // check groups > 1 PADDLE_ENFORCE_GT( groups, 1, - "in maxoutop groups should be larger than 1"); - - + "groups should be larger than 1 in maxoutop"); std::vector output_shape({in_x_dims[0], in_x_dims[1] / groups}); output_shape.push_back(in_x_dims[2]); output_shape.push_back(in_x_dims[3]); - ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); } }; - class MaxOutOpGrad : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null."); PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), @@ -129,8 +98,6 @@ class MaxOutOpGrad : public framework::OperatorWithKernel { namespace ops = paddle::operators; REGISTER_OP(maxout, ops::MaxOutOp, ops::MaxOutOpMaker, maxout_grad, ops::MaxOutOpGrad); - - REGISTER_OP_CPU_KERNEL(maxout, ops::MaxOutKernel); REGISTER_OP_CPU_KERNEL(maxout_grad, diff --git a/paddle/operators/maxout_op.h b/paddle/operators/maxout_op.h index aab878af0f..6c769838c3 100644 --- a/paddle/operators/maxout_op.h +++ b/paddle/operators/maxout_op.h @@ -29,16 +29,12 @@ class MaxOutKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& context) const override { const Tensor* in_x = context.Input("X"); Tensor* out = context.Output("Out"); - int groups = context.template Attr("groups"); - paddle::operators::math::MaxOutFunctor< - Place, paddle::operators::math::MaxOut, T> + Place, T> maxout_forward; - paddle::operators::math::MaxOut maxout_process; - maxout_forward(context.device_context(), *in_x, out, groups, - maxout_process); + maxout_forward(context.device_context(), *in_x, out, groups); } }; @@ -51,15 +47,12 @@ class MaxOutGradKernel : public framework::OpKernel { const Tensor* out_grad = context.Input(framework::GradVarName("Out")); Tensor* in_x_grad = context.Output(framework::GradVarName("X")); - int groups = context.template Attr("groups"); - auto& device_ctx = context.device_context(); math::SetConstant zero; if (in_x_grad) { in_x_grad->mutable_data(context.GetPlace()); zero(device_ctx, in_x_grad, static_cast(0.0)); - paddle::operators::math::MaxOutGradFunctor maxout_backward; maxout_backward(context.device_context(), *in_x, *in_x_grad, *out, From a6a01c15f5049b56d48dcf8a146b6825fcb0c248 Mon Sep 17 00:00:00 2001 From: wanghaox Date: Sun, 19 Nov 2017 17:47:48 +0800 Subject: [PATCH 07/15] add test_maxout_op framework to fluis --- .../paddle/v2/fluid/tests/test_maxout_op.py | 41 +++++++++++++++++++ 1 file changed, 41 insertions(+) create mode 100644 python/paddle/v2/fluid/tests/test_maxout_op.py diff --git a/python/paddle/v2/fluid/tests/test_maxout_op.py b/python/paddle/v2/fluid/tests/test_maxout_op.py new file mode 100644 index 0000000000..a7c47108f1 --- /dev/null +++ b/python/paddle/v2/fluid/tests/test_maxout_op.py @@ -0,0 +1,41 @@ +import unittest +import numpy as np +from 
op_test import OpTest + + +def maxout_forward_naive(input, groups,num_channels): + s0, s1, s2, s3 = input.shape + return np.ndarray([s0, s1 / groups, groups, s2, s3], \ + buffer = input, dtype=input.dtype).max(axis=(2)) + + +class TestMaxOutOp(OpTest): + def setUp(self): + self.op_type = "maxout" + self.init_test_case() + input = np.random.random(self.shape).astype("float32") + output = self.MaxOut_forward_naive(input, self.groups, + self.num_channels).astype("float32") + + self.inputs = {'X': input} + self.attrs = {'groups': self.groups, 'num_channels': self.num_channels} + + self.outputs = {'Out': output.astype('float32')} + + def test_check_output(self): + self.check_output() + + def test_check_grad(self): + self.check_grad(['X'], 'Out') + + def init_test_case(self): + self.MaxOut_forward_naive = maxout_forward_naive + self.shape = [100, 6, 2, 2] + self.groups=2 + self.num_channels=6 + + + + +if __name__ == '__main__': + unittest.main() From 25d76bc7e147d7cef53a1704c81de4b7d07d0f5f Mon Sep 17 00:00:00 2001 From: wanghaox Date: Mon, 20 Nov 2017 10:18:09 +0800 Subject: [PATCH 08/15] modify for add a space in maxout op --- paddle/operators/math/maxouting.cc | 2 +- paddle/operators/math/maxouting.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/paddle/operators/math/maxouting.cc b/paddle/operators/math/maxouting.cc index baaa86ffce..aa8d44d2ff 100644 --- a/paddle/operators/math/maxouting.cc +++ b/paddle/operators/math/maxouting.cc @@ -47,7 +47,7 @@ class MaxOutFunctor { // T ele = maxout_process.initial(); T ele = static_cast(-FLT_MAX); for (int ph = 0; ph < groups; ++ph) { - T x=input_data[(new_bindex+new_cindex) * groups+ph*fea_size+f]; + T x = input_data[(new_bindex+new_cindex) * groups+ph*fea_size+f]; ele = ele > x ? ele : x; } output_data[(new_bindex+new_cindex+f)] = ele; diff --git a/paddle/operators/math/maxouting.h b/paddle/operators/math/maxouting.h index 72f40d96f7..76a256add9 100644 --- a/paddle/operators/math/maxouting.h +++ b/paddle/operators/math/maxouting.h @@ -30,7 +30,7 @@ class MaxOutFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& input, framework::Tensor * output, - int groups ); + int groups); }; template From 2d7a652869da626c6418328e5786f1335fb63c1a Mon Sep 17 00:00:00 2001 From: wanghaox Date: Mon, 20 Nov 2017 10:29:09 +0800 Subject: [PATCH 09/15] del framework test_maxout_op --- paddle/operators/math/maxouting.cu | 2 +- .../v2/framework/tests/test_maxout_op.py | 41 ------------------- 2 files changed, 1 insertion(+), 42 deletions(-) delete mode 100644 python/paddle/v2/framework/tests/test_maxout_op.py diff --git a/paddle/operators/math/maxouting.cu b/paddle/operators/math/maxouting.cu index 1a8fc465cc..336a1bd8b5 100644 --- a/paddle/operators/math/maxouting.cu +++ b/paddle/operators/math/maxouting.cu @@ -37,7 +37,7 @@ __global__ void KernelMaxOut(const int nthreads, const T* input_data, (batch_idx * size + channel_idx * feat_len) * groups + feat_idx; T ele = static_cast(-FLT_MAX); for (int g = 0; g < groups; ++g) { - T x=input_data[data_idx + g * feat_len]; + T x = input_data[data_idx + g * feat_len]; ele = ele > x ? 
ele : x; } output_data[i] = ele; diff --git a/python/paddle/v2/framework/tests/test_maxout_op.py b/python/paddle/v2/framework/tests/test_maxout_op.py deleted file mode 100644 index a7c47108f1..0000000000 --- a/python/paddle/v2/framework/tests/test_maxout_op.py +++ /dev/null @@ -1,41 +0,0 @@ -import unittest -import numpy as np -from op_test import OpTest - - -def maxout_forward_naive(input, groups,num_channels): - s0, s1, s2, s3 = input.shape - return np.ndarray([s0, s1 / groups, groups, s2, s3], \ - buffer = input, dtype=input.dtype).max(axis=(2)) - - -class TestMaxOutOp(OpTest): - def setUp(self): - self.op_type = "maxout" - self.init_test_case() - input = np.random.random(self.shape).astype("float32") - output = self.MaxOut_forward_naive(input, self.groups, - self.num_channels).astype("float32") - - self.inputs = {'X': input} - self.attrs = {'groups': self.groups, 'num_channels': self.num_channels} - - self.outputs = {'Out': output.astype('float32')} - - def test_check_output(self): - self.check_output() - - def test_check_grad(self): - self.check_grad(['X'], 'Out') - - def init_test_case(self): - self.MaxOut_forward_naive = maxout_forward_naive - self.shape = [100, 6, 2, 2] - self.groups=2 - self.num_channels=6 - - - - -if __name__ == '__main__': - unittest.main() From c645d065fedd691ec1bd5782a5fcf34f6e355055 Mon Sep 17 00:00:00 2001 From: wanghaox Date: Mon, 20 Nov 2017 13:24:56 +0800 Subject: [PATCH 10/15] add a space + * --- paddle/operators/math/maxouting.cc | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/paddle/operators/math/maxouting.cc b/paddle/operators/math/maxouting.cc index aa8d44d2ff..a4d46ccc98 100644 --- a/paddle/operators/math/maxouting.cc +++ b/paddle/operators/math/maxouting.cc @@ -47,7 +47,8 @@ class MaxOutFunctor { // T ele = maxout_process.initial(); T ele = static_cast(-FLT_MAX); for (int ph = 0; ph < groups; ++ph) { - T x = input_data[(new_bindex+new_cindex) * groups+ph*fea_size+f]; + T x = input_data[(new_bindex + new_cindex) * groups + + ph * fea_size + f]; ele = ele > x ? ele : x; } output_data[(new_bindex+new_cindex+f)] = ele; From 76fc1a82e109737d704b11d897b83b5f5138bc86 Mon Sep 17 00:00:00 2001 From: wanghaox Date: Mon, 20 Nov 2017 14:33:28 +0800 Subject: [PATCH 11/15] for code review 4 --- paddle/operators/math/maxouting.cc | 10 +++------- .../math/{maxouting.cu => maxouting.cu.cc} | 5 +++-- paddle/operators/math/maxouting.h | 2 +- paddle/operators/maxout_op.cc | 15 +++++++-------- .../operators/{maxout_op.cu => maxout_op.cu.cc} | 1 - paddle/operators/maxout_op.h | 11 ++++------- python/paddle/v2/fluid/tests/test_maxout_op.py | 5 ++--- 7 files changed, 20 insertions(+), 29 deletions(-) rename paddle/operators/math/{maxouting.cu => maxouting.cu.cc} (97%) rename paddle/operators/{maxout_op.cu => maxout_op.cu.cc} (97%) diff --git a/paddle/operators/math/maxouting.cc b/paddle/operators/math/maxouting.cc index a4d46ccc98..c8c1974f79 100644 --- a/paddle/operators/math/maxouting.cc +++ b/paddle/operators/math/maxouting.cc @@ -18,10 +18,7 @@ namespace paddle { namespace operators { namespace math { -/* - * All tensors are in NCHW format. 
- * groups mustbe > 1 - */ +// All tensors are in NCHW format, and the groups must be greater than 1 template class MaxOutFunctor { public: @@ -44,7 +41,6 @@ class MaxOutFunctor { for (int c = 0; c < output_channels; ++c) { int new_cindex = fea_size * c; for (int f = 0; f < fea_size; ++f) { - // T ele = maxout_process.initial(); T ele = static_cast(-FLT_MAX); for (int ph = 0; ph < groups; ++ph) { T x = input_data[(new_bindex + new_cindex) * groups @@ -65,7 +61,7 @@ class MaxOutGradFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& input, - framework::Tensor& input_grad, + framework::Tensor * input_grad, const framework::Tensor& output, const framework::Tensor& output_grad, int groups) { @@ -77,7 +73,7 @@ public: const T* input_data = input.data(); const T* output_data = output.data(); const T* output_grad_data = output_grad.data(); - T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + T* input_grad_data = input_grad->mutable_data(context.GetPlace()); for (int i = 0; i < batch_size; ++i) { int blen = fea_size * output_channels * i; diff --git a/paddle/operators/math/maxouting.cu b/paddle/operators/math/maxouting.cu.cc similarity index 97% rename from paddle/operators/math/maxouting.cu rename to paddle/operators/math/maxouting.cu.cc index 336a1bd8b5..3a0600fd84 100644 --- a/paddle/operators/math/maxouting.cu +++ b/paddle/operators/math/maxouting.cu.cc @@ -112,7 +112,8 @@ template class MaxOutGradFunctor { public: void operator()(const platform::DeviceContext& context, - const framework::Tensor& input, framework::Tensor& input_grad, + const framework::Tensor& input, + framework::Tensor * input_grad, const framework::Tensor& output, const framework::Tensor& output_grad, int groups) { @@ -127,7 +128,7 @@ class MaxOutGradFunctor { const T* input_data = input.data(); const T* output_data = output.data(); const T* output_grad_data = output_grad.data(); - T* input_grad_data = input_grad.mutable_data(context.GetPlace()); + T* input_grad_data = input_grad->mutable_data(context.GetPlace()); int nthreads = output.numel(); int blocks = (nthreads + 1024 - 1) / 1024; dim3 threads(1024, 1); diff --git a/paddle/operators/math/maxouting.h b/paddle/operators/math/maxouting.h index 76a256add9..d4c9da38ab 100644 --- a/paddle/operators/math/maxouting.h +++ b/paddle/operators/math/maxouting.h @@ -38,7 +38,7 @@ class MaxOutGradFunctor { public: void operator()(const platform::DeviceContext& context, const framework::Tensor& input, - framework::Tensor& input_grad, + framework::Tensor * input_grad, const framework::Tensor& output, const framework::Tensor& output_grad, int groups); }; diff --git a/paddle/operators/maxout_op.cc b/paddle/operators/maxout_op.cc index f9277518cc..95467f2e69 100644 --- a/paddle/operators/maxout_op.cc +++ b/paddle/operators/maxout_op.cc @@ -34,14 +34,13 @@ class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker { "width of feature."); AddAttr( "groups", - R"DOC(The group number of input layer. + R"DOC("Specifies how many groups the input tensor will be split" + "in the channel dimension. And the number of output channel is " + "the number of channels divided by groups.." )DOC"); AddComment(R"DOC( - - Input: NCHW. - - Output: The feature map size of output is the same as the input. - The output_channel is (input channel) / groups - So groups should be larger than 1, and the num of channels should be able - to be devided by groups. + Assumed the input shape is (N, Ci, H, W). + The output shape is (N, Co, H, W). 
Then `Co = Ci / groups`. math: y_{si+j} = \max_k x_{gsi + sk + j} @@ -65,10 +64,10 @@ class MaxOutOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { - PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of maxoutOp" + PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of MaxoutOp" "should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of maxoutOp should not be null."); + "Output(Out) of MaxoutOp should not be null."); auto in_x_dims = ctx->GetInputDim("X"); int groups = ctx->Attrs().Get("groups"); // check groups > 1 diff --git a/paddle/operators/maxout_op.cu b/paddle/operators/maxout_op.cu.cc similarity index 97% rename from paddle/operators/maxout_op.cu rename to paddle/operators/maxout_op.cu.cc index 44a149b065..3e6debf699 100644 --- a/paddle/operators/maxout_op.cu +++ b/paddle/operators/maxout_op.cu.cc @@ -12,7 +12,6 @@ See the License for the specific language governing permissions and limitations under the License. */ -#define EIGEN_USE_GPU #include "paddle/operators/maxout_op.h" namespace ops = paddle::operators; diff --git a/paddle/operators/maxout_op.h b/paddle/operators/maxout_op.h index 6c769838c3..c404cd16a9 100644 --- a/paddle/operators/maxout_op.h +++ b/paddle/operators/maxout_op.h @@ -31,9 +31,7 @@ class MaxOutKernel : public framework::OpKernel { Tensor* out = context.Output("Out"); int groups = context.template Attr("groups"); - paddle::operators::math::MaxOutFunctor< - Place, T> - maxout_forward; + math::MaxOutFunctor maxout_forward; maxout_forward(context.device_context(), *in_x, out, groups); } }; @@ -53,10 +51,9 @@ class MaxOutGradKernel : public framework::OpKernel { if (in_x_grad) { in_x_grad->mutable_data(context.GetPlace()); zero(device_ctx, in_x_grad, static_cast(0.0)); - paddle::operators::math::MaxOutGradFunctor - maxout_backward; - maxout_backward(context.device_context(), *in_x, *in_x_grad, *out, - *out_grad, groups); + math::MaxOutGradFunctor maxout_backward; + maxout_backward(context.device_context(), *in_x, in_x_grad, *out, + *out_grad, groups); } } }; diff --git a/python/paddle/v2/fluid/tests/test_maxout_op.py b/python/paddle/v2/fluid/tests/test_maxout_op.py index a7c47108f1..1416e13feb 100644 --- a/python/paddle/v2/fluid/tests/test_maxout_op.py +++ b/python/paddle/v2/fluid/tests/test_maxout_op.py @@ -3,7 +3,7 @@ import numpy as np from op_test import OpTest -def maxout_forward_naive(input, groups,num_channels): +def maxout_forward_naive(input, groups): s0, s1, s2, s3 = input.shape return np.ndarray([s0, s1 / groups, groups, s2, s3], \ buffer = input, dtype=input.dtype).max(axis=(2)) @@ -18,7 +18,7 @@ class TestMaxOutOp(OpTest): self.num_channels).astype("float32") self.inputs = {'X': input} - self.attrs = {'groups': self.groups, 'num_channels': self.num_channels} + self.attrs = {'groups': self.groups} self.outputs = {'Out': output.astype('float32')} @@ -32,7 +32,6 @@ class TestMaxOutOp(OpTest): self.MaxOut_forward_naive = maxout_forward_naive self.shape = [100, 6, 2, 2] self.groups=2 - self.num_channels=6 From 4e5c989669a5ad8c73d638f09f2cb6664763fd4b Mon Sep 17 00:00:00 2001 From: sweetsky0901 <32288640+sweetsky0901@users.noreply.github.com> Date: Mon, 20 Nov 2017 15:25:45 +0800 Subject: [PATCH 12/15] rename back --- paddle/operators/math/{maxouting.cu.cc => maxouting.cu} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename paddle/operators/math/{maxouting.cu.cc => maxouting.cu} (100%) diff --git 
a/paddle/operators/math/maxouting.cu.cc b/paddle/operators/math/maxouting.cu similarity index 100% rename from paddle/operators/math/maxouting.cu.cc rename to paddle/operators/math/maxouting.cu From 3fbff1ee787bdcf9dd653fa7ea7f3e3732c5423f Mon Sep 17 00:00:00 2001 From: sweetsky0901 Date: Mon, 20 Nov 2017 16:41:14 +0800 Subject: [PATCH 13/15] for code review 5 --- paddle/operators/math/maxouting.cc | 1 + paddle/operators/math/maxouting.cu | 1 + paddle/operators/maxout_op.cu.cc | 5 +++++ 3 files changed, 7 insertions(+) diff --git a/paddle/operators/math/maxouting.cc b/paddle/operators/math/maxouting.cc index c8c1974f79..bcd4da612c 100644 --- a/paddle/operators/math/maxouting.cc +++ b/paddle/operators/math/maxouting.cc @@ -89,6 +89,7 @@ public: if (input_data[input_idx] == output_data[output_idx]) { input_grad_data[input_idx] += output_grad_data[output_idx]; continue_match = false; + break; } } } diff --git a/paddle/operators/math/maxouting.cu b/paddle/operators/math/maxouting.cu index 3a0600fd84..0a8afbbaca 100644 --- a/paddle/operators/math/maxouting.cu +++ b/paddle/operators/math/maxouting.cu @@ -65,6 +65,7 @@ __global__ void KernelMaxoutGrad( if (input_data[data_idx + g * feat_len] == output_data[i]) { max_index = data_idx + g * feat_len; continue_match = false; + break; } } if (max_index != -1) { diff --git a/paddle/operators/maxout_op.cu.cc b/paddle/operators/maxout_op.cu.cc index 3e6debf699..5ee431cb26 100644 --- a/paddle/operators/maxout_op.cu.cc +++ b/paddle/operators/maxout_op.cu.cc @@ -17,6 +17,11 @@ namespace ops = paddle::operators; REGISTER_OP_GPU_KERNEL(maxout, ops::MaxOutKernel); +REGISTER_OP_GPU_KERNEL(maxout, ops::MaxOutKernel); REGISTER_OP_GPU_KERNEL(maxout_grad, ops::MaxOutGradKernel); +REGISTER_OP_GPU_KERNEL(maxout_grad, + ops::MaxOutGradKernel); From 04fd98930b53e587f95f3ba5dc7f5999472cde00 Mon Sep 17 00:00:00 2001 From: sweetsky0901 Date: Mon, 20 Nov 2017 18:24:19 +0800 Subject: [PATCH 14/15] for code review 6 --- paddle/operators/math/maxouting.cc | 2 -- paddle/operators/math/maxouting.cu | 3 +-- paddle/operators/maxout_op.cu.cc | 12 +++++------- 3 files changed, 6 insertions(+), 11 deletions(-) diff --git a/paddle/operators/math/maxouting.cc b/paddle/operators/math/maxouting.cc index bcd4da612c..e5168ce7af 100644 --- a/paddle/operators/math/maxouting.cc +++ b/paddle/operators/math/maxouting.cc @@ -85,11 +85,9 @@ public: int output_idx = blen + clen + f; for (int g = 0; g < groups && continue_match; ++g) { int input_idx = input_idx0 + fea_size * g; - input_grad_data[input_idx] = 0; if (input_data[input_idx] == output_data[output_idx]) { input_grad_data[input_idx] += output_grad_data[output_idx]; continue_match = false; - break; } } } diff --git a/paddle/operators/math/maxouting.cu b/paddle/operators/math/maxouting.cu index 0a8afbbaca..7c698577b8 100644 --- a/paddle/operators/math/maxouting.cu +++ b/paddle/operators/math/maxouting.cu @@ -69,8 +69,7 @@ __global__ void KernelMaxoutGrad( } } if (max_index != -1) { - // atomic add - platform::CudaAtomicAdd(input_grad + max_index, output_grad[index]); + input_grad[max_index] += output_grad[index]; } } } diff --git a/paddle/operators/maxout_op.cu.cc b/paddle/operators/maxout_op.cu.cc index 5ee431cb26..a5823fba68 100644 --- a/paddle/operators/maxout_op.cu.cc +++ b/paddle/operators/maxout_op.cu.cc @@ -15,13 +15,11 @@ #include "paddle/operators/maxout_op.h" namespace ops = paddle::operators; -REGISTER_OP_GPU_KERNEL(maxout, ops::MaxOutKernel); -REGISTER_OP_GPU_KERNEL(maxout, ops::MaxOutKernel); +REGISTER_OP_GPU_KERNEL(maxout, + 
ops::MaxOutKernel<paddle::platform::GPUPlace, float>,
+                       ops::MaxOutKernel<paddle::platform::GPUPlace, double>);
 REGISTER_OP_GPU_KERNEL(maxout_grad,
                        ops::MaxOutGradKernel<paddle::platform::GPUPlace,
-                                             float>);
-REGISTER_OP_GPU_KERNEL(maxout_grad,
+                                             float>,
                        ops::MaxOutGradKernel<paddle::platform::GPUPlace,
-                                             double>);
+                                             double>);

From 9cb2ff6a3b473c4f930effbb6ec4d4e856676ad3 Mon Sep 17 00:00:00 2001
From: sweetsky0901 <32288640+sweetsky0901@users.noreply.github.com>
Date: Mon, 20 Nov 2017 19:40:25 +0800
Subject: [PATCH 15/15] del num_channels

---
 python/paddle/v2/fluid/tests/test_maxout_op.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/python/paddle/v2/fluid/tests/test_maxout_op.py b/python/paddle/v2/fluid/tests/test_maxout_op.py
index 1416e13feb..05e42f3158 100644
--- a/python/paddle/v2/fluid/tests/test_maxout_op.py
+++ b/python/paddle/v2/fluid/tests/test_maxout_op.py
@@ -14,8 +14,7 @@ class TestMaxOutOp(OpTest):
         self.op_type = "maxout"
         self.init_test_case()
         input = np.random.random(self.shape).astype("float32")
-        output = self.MaxOut_forward_naive(input, self.groups,
-                                           self.num_channels).astype("float32")
+        output = self.MaxOut_forward_naive(input, self.groups).astype("float32")

         self.inputs = {'X': input}
         self.attrs = {'groups': self.groups}
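
For readers following the patch series, the computation these kernels implement can be sketched with plain numpy, outside the operator framework. The snippet below is an illustration only, not part of any patch: the function names and toy shapes are invented for the example, and only the reshape/max logic mirrors maxout_forward_naive in test_maxout_op.py and the documented formula y_{si+j} = \max_k x_{gsi + sk + j}. The backward sketch routes each output gradient to the first input element in its group that attained the maximum, which is the behaviour of MaxOutGradFunctor and KernelMaxoutGrad.

import numpy as np

# Illustrative sketch of the maxout forward pass described in this PR.
# An NCHW input of shape (N, C, H, W) becomes (N, C/groups, H, W) by taking
# the element-wise max over each group of `groups` consecutive channels.
def maxout_forward(x, groups):
    n, c, h, w = x.shape
    assert c % groups == 0, "channel count must be divisible by groups"
    return x.reshape(n, c // groups, groups, h, w).max(axis=2)

# Illustrative sketch of the gradient routing: each output gradient is sent
# back to the first input position in its group that equals the output value,
# matching the early-exit loop in the CPU and GPU grad kernels.
def maxout_backward(x, dout, groups):
    n, c, h, w = x.shape
    xr = x.reshape(n, c // groups, groups, h, w)
    mask = xr == xr.max(axis=2, keepdims=True)
    # Keep only the first match along the group axis, as the kernels stop
    # after the first hit.
    first = np.cumsum(mask, axis=2) == 1
    dx = (mask & first) * dout.reshape(n, c // groups, 1, h, w)
    return dx.reshape(n, c, h, w)

if __name__ == "__main__":
    x = np.random.random((2, 6, 3, 3)).astype("float32")   # N=2, C=6, H=W=3
    out = maxout_forward(x, groups=2)                       # shape (2, 3, 3, 3)
    dx = maxout_backward(x, np.ones_like(out), groups=2)    # shape (2, 6, 3, 3)
    print(out.shape, dx.shape)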