commit
b80cdcea11
@@ -0,0 +1,106 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/maxouting.h"

namespace paddle {
namespace operators {
namespace math {

// All tensors are in NCHW format, and groups must be greater than 1.
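// For example (illustrative only): with groups = 2, an input of shape
// (N, 4, H, W) produces an output of shape (N, 2, H, W), where
//   output[n][0][h][w] = max(input[n][0][h][w], input[n][1][h][w])
//   output[n][1][h][w] = max(input[n][2][h][w], input[n][3][h][w])
// i.e. each output channel is the element-wise maximum over `groups`
// consecutive input channels.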
template <typename T>
class MaxOutFunctor<platform::CPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, framework::Tensor* output,
                  int groups) {
    const int batch_size = input.dims()[0];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output->dims()[1];
    int fea_size = input_height * input_width;
    // c_size is the output size of each sample.
    int c_size = fea_size * output_channels;
    const T* input_data = input.data<T>();
    T* output_data = output->mutable_data<T>(context.GetPlace());

    for (int i = 0; i < batch_size; ++i) {
      int new_bindex = c_size * i;
      for (int c = 0; c < output_channels; ++c) {
        int new_cindex = fea_size * c;
        for (int f = 0; f < fea_size; ++f) {
          // Input channels (c * groups) .. (c * groups + groups - 1) feed
          // output channel c; take their element-wise maximum.
          T ele = static_cast<T>(-FLT_MAX);
          for (int ph = 0; ph < groups; ++ph) {
            T x = input_data[(new_bindex + new_cindex) * groups +
                             ph * fea_size + f];
            ele = ele > x ? ele : x;
          }
          output_data[new_bindex + new_cindex + f] = ele;
        }
      }
    }
  }
};

template <class T>
class MaxOutGradFunctor<platform::CPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  framework::Tensor* input_grad,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad, int groups) {
    const int batch_size = input.dims()[0];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output.dims()[1];
    int fea_size = input_height * input_width;
    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());

    for (int i = 0; i < batch_size; ++i) {
      int blen = fea_size * output_channels * i;
      for (int c = 0; c < output_channels; ++c) {
        int clen = fea_size * c;
        for (int f = 0; f < fea_size; ++f) {
          int input_idx0 = (blen + clen) * groups + f;
          bool continue_match = true;
          int output_idx = blen + clen + f;
          // Route the gradient to the first input element in the group that
          // equals the forward maximum.
          for (int g = 0; g < groups && continue_match; ++g) {
            int input_idx = input_idx0 + fea_size * g;
            if (input_data[input_idx] == output_data[output_idx]) {
              input_grad_data[input_idx] += output_grad_data[output_idx];
              continue_match = false;
            }
          }
        }
      }
    }
  }
};

template class MaxOutGradFunctor<platform::CPUPlace, float>;
template class MaxOutGradFunctor<platform::CPUPlace, double>;
template class MaxOutFunctor<platform::CPUPlace, float>;
template class MaxOutFunctor<platform::CPUPlace, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,154 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/math/maxouting.h"
#include "paddle/platform/cuda_helper.h"

namespace paddle {
namespace operators {
namespace math {

template <typename T>
__global__ void KernelMaxOut(const int nthreads, const T* input_data,
                             const int channels, const int input_height,
                             const int input_width, int groups,
                             T* output_data) {
  const int size = input_height * input_width * channels / groups;
  const int feat_len = input_height * input_width;
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int offset = blockDim.x * gridDim.x;
  // Grid-stride loop: each thread computes one output element per iteration.
  for (int i = index; i < nthreads; i += offset) {
    // Decompose the flat output index into (batch, channel, spatial position)
    // and locate the first of the `groups` input elements that feed it.
    int batch_idx = i / size;
    int batch_offset = i % size;
    int channel_idx = batch_offset / feat_len;
    int feat_idx = batch_offset % feat_len;
    int data_idx =
        (batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
    T ele = static_cast<T>(-FLT_MAX);
    for (int g = 0; g < groups; ++g) {
      T x = input_data[data_idx + g * feat_len];
      ele = ele > x ? ele : x;
    }
    output_data[i] = ele;
  }
}

template <typename T>
__global__ void KernelMaxoutGrad(
    const int nthreads, const T* input_data, const T* output_data,
    const T* output_grad, T* input_grad, const int channels,
    const int input_height, const int input_width, int groups) {
  const int size = input_height * input_width * channels / groups;
  const int feat_len = input_height * input_width;
  int index = blockIdx.x * blockDim.x + threadIdx.x;
  int offset = blockDim.x * gridDim.x;
  for (int i = index; i < nthreads; i += offset) {
    int batch_idx = i / size;
    int batch_offset = i % size;
    int channel_idx = batch_offset / feat_len;
    int feat_idx = batch_offset % feat_len;
    int data_idx =
        (batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
    int max_index = -1;
    // Find the input element that produced the maximum for output element i.
    for (int g = 0; g < groups; ++g) {
      if (input_data[data_idx + g * feat_len] == output_data[i]) {
        max_index = data_idx + g * feat_len;
        break;
      }
    }
    if (max_index != -1) {
      input_grad[max_index] += output_grad[i];
    }
  }
}
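// Note: plain (non-atomic) accumulation into input_grad is safe here because
// each input element belongs to exactly one output element's group, so no two
// threads write the same location; the caller is expected to zero-initialize
// input_grad (as MaxOutGradKernel does with SetConstant) before launching
// KernelMaxoutGrad.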
/*
 * All tensors are in NCHW format.
 */
template <typename T>
class MaxOutFunctor<platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, framework::Tensor* output,
                  int groups) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output->dims()[1];
    const int output_height = output->dims()[2];
    const int output_width = output->dims()[3];

    const T* input_data = input.data<T>();
    T* output_data = output->mutable_data<T>(context.GetPlace());
    int nthreads = output->numel();
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KernelMaxOut<
        T><<<grid, threads, 0,
             reinterpret_cast<const platform::CUDADeviceContext&>(context)
                 .stream()>>>(nthreads, input_data, input_channels,
                              input_height, input_width, groups, output_data);
  }
};
/*
 * All tensors are in NCHW format.
 */
template <typename T>
class MaxOutGradFunctor<platform::GPUPlace, T> {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  framework::Tensor* input_grad,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad, int groups) {
    const int batch_size = input.dims()[0];
    const int input_channels = input.dims()[1];
    const int input_height = input.dims()[2];
    const int input_width = input.dims()[3];
    const int output_channels = output.dims()[1];
    const int output_height = output.dims()[2];
    const int output_width = output.dims()[3];

    const T* input_data = input.data<T>();
    const T* output_data = output.data<T>();
    const T* output_grad_data = output_grad.data<T>();
    T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
    int nthreads = output.numel();
    int blocks = (nthreads + 1024 - 1) / 1024;
    dim3 threads(1024, 1);
    dim3 grid(blocks, 1);

    KernelMaxoutGrad<
        T><<<grid, threads, 0,
             reinterpret_cast<const platform::CUDADeviceContext&>(context)
                 .stream()>>>(
        nthreads, input_data, output_data, output_grad_data, input_grad_data,
        input_channels, input_height, input_width, groups);
  }
};

template class MaxOutGradFunctor<platform::GPUPlace, float>;
template class MaxOutGradFunctor<platform::GPUPlace, double>;

template class MaxOutFunctor<platform::GPUPlace, float>;
template class MaxOutFunctor<platform::GPUPlace, double>;

}  // namespace math
}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,47 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/hostdevice.h"

namespace paddle {
namespace operators {
namespace math {

#define FLT_MAX \
  __FLT_MAX__

template <typename Place, typename T>
class MaxOutFunctor {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input, framework::Tensor* output,
                  int groups);
};

template <typename Place, class T>
class MaxOutGradFunctor {
 public:
  void operator()(const platform::DeviceContext& context,
                  const framework::Tensor& input,
                  framework::Tensor* input_grad,
                  const framework::Tensor& output,
                  const framework::Tensor& output_grad, int groups);
};
}  // namespace math
}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,104 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#include "paddle/operators/maxout_op.h"
namespace paddle {
namespace operators {

using framework::Tensor;

class MaxOutOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  MaxOutOpMaker(framework::OpProto* proto, framework::OpAttrChecker* op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X",
             "(Tensor) The input tensor of the maxout operator. "
             "The format of the input tensor is NCHW, where N is the batch "
             "size, C is the number of channels, and H and W are the height "
             "and width of the feature map.");
    AddOutput("Out",
              "(Tensor) The output tensor of the maxout operator. "
              "The format of the output tensor is also NCHW, where N is the "
              "batch size, C is the number of channels, and H and W are the "
              "height and width of the feature map.");
    AddAttr<int>("groups",
                 R"DOC(Specifies how many groups the input tensor will be split
into along the channel dimension; the number of output channels is the number
of input channels divided by groups.)DOC");
    AddComment(R"DOC(
Assume the input has shape (N, Ci, H, W).
The output has shape (N, Co, H, W), where `Co = Ci / groups`.

math:
  y_{si+j} = \max_k x_{gsi + sk + j}
  g = groups
  s = input.size / num_channels
  0 \le i < num_channels / groups
  0 \le j < s
  0 \le k < groups

Please refer to the papers:
- Maxout Networks: http://www.jmlr.org/proceedings/papers/v28/goodfellow13.pdf
- Multi-digit Number Recognition from Street View Imagery using Deep
  Convolutional Neural Networks: https://arxiv.org/pdf/1312.6082v4.pdf
)DOC");
  }
};
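// As an illustration of the formula above (not part of the operator's
// interface): with groups g = 2 and an input X of shape (N, 4, H, W),
// s = H * W and the output Y has shape (N, 2, H, W) with
//   Y[n][i][h][w] = max(X[n][2 * i][h][w], X[n][2 * i + 1][h][w]),  i = 0, 1.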

class MaxOutOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"),
                   "Input(X) of MaxOutOp should not be null.");
    PADDLE_ENFORCE(ctx->HasOutput("Out"),
                   "Output(Out) of MaxOutOp should not be null.");
    auto in_x_dims = ctx->GetInputDim("X");
    int groups = ctx->Attrs().Get<int>("groups");
    // groups must be greater than 1.
    PADDLE_ENFORCE_GT(groups, 1,
                      "Attr(groups) of MaxOutOp should be larger than 1.");
    std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1] / groups});
    output_shape.push_back(in_x_dims[2]);
    output_shape.push_back(in_x_dims[3]);
    ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
  }
};

class MaxOutOpGrad : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override {
    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
                   "Output(X@GRAD) should not be null.");
    ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP(maxout, ops::MaxOutOp, ops::MaxOutOpMaker, maxout_grad,
            ops::MaxOutOpGrad);
REGISTER_OP_CPU_KERNEL(maxout,
                       ops::MaxOutKernel<paddle::platform::CPUPlace, float>);
REGISTER_OP_CPU_KERNEL(
    maxout_grad, ops::MaxOutGradKernel<paddle::platform::CPUPlace, float>);
@@ -0,0 +1,25 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/operators/maxout_op.h"

namespace ops = paddle::operators;
REGISTER_OP_GPU_KERNEL(maxout,
                       ops::MaxOutKernel<paddle::platform::GPUPlace, float>,
                       ops::MaxOutKernel<paddle::platform::GPUPlace, double>);
REGISTER_OP_GPU_KERNEL(
    maxout_grad, ops::MaxOutGradKernel<paddle::platform::GPUPlace, float>,
    ops::MaxOutGradKernel<paddle::platform::GPUPlace, double>);
@@ -0,0 +1,62 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/framework/op_registry.h"
#include "paddle/operators/math/math_function.h"
#include "paddle/operators/math/maxouting.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename Place, typename T>
class MaxOutKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* in_x = context.Input<Tensor>("X");
    Tensor* out = context.Output<Tensor>("Out");
    int groups = context.template Attr<int>("groups");

    math::MaxOutFunctor<Place, T> maxout_forward;
    maxout_forward(context.device_context(), *in_x, out, groups);
  }
};

template <typename Place, typename T>
class MaxOutGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const Tensor* in_x = context.Input<Tensor>("X");
    const Tensor* out = context.Input<Tensor>("Out");
    const Tensor* out_grad =
        context.Input<Tensor>(framework::GradVarName("Out"));
    Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
    int groups = context.template Attr<int>("groups");
    auto& device_ctx = context.device_context();
    math::SetConstant<Place, T> zero;
    if (in_x_grad) {
      in_x_grad->mutable_data<T>(context.GetPlace());
      // Zero the gradient buffer before the backward functor scatters into it.
      zero(device_ctx, in_x_grad, static_cast<T>(0.0));
      math::MaxOutGradFunctor<Place, T> maxout_backward;
      maxout_backward(context.device_context(), *in_x, in_x_grad, *out,
                      *out_grad, groups);
    }
  }
};

}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,39 @@
import unittest
import numpy as np
from op_test import OpTest


def maxout_forward_naive(input, groups):
    s0, s1, s2, s3 = input.shape
    # View (N, C, H, W) as (N, C // groups, groups, H, W) over the same buffer
    # and take the maximum over the group axis.
    return np.ndarray([s0, s1 // groups, groups, s2, s3],
                      buffer=input, dtype=input.dtype).max(axis=2)
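# For example (illustrative only), with groups=2 and a single pixel:
#   x = np.array([1., 5., 2., 3.], dtype="float32").reshape(1, 4, 1, 1)
#   maxout_forward_naive(x, 2).reshape(-1)  ->  [5., 3.]
# i.e. channels are reduced pairwise by taking the maximum.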


class TestMaxOutOp(OpTest):
    def setUp(self):
        self.op_type = "maxout"
        self.init_test_case()
        input = np.random.random(self.shape).astype("float32")
        output = self.MaxOut_forward_naive(input, self.groups).astype("float32")

        self.inputs = {'X': input}
        self.attrs = {'groups': self.groups}

        self.outputs = {'Out': output}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')

    def init_test_case(self):
        self.MaxOut_forward_naive = maxout_forward_naive
        self.shape = [100, 6, 2, 2]
        self.groups = 2


if __name__ == '__main__':
    unittest.main()