|
|
|
@ -24,21 +24,20 @@ __global__ void KernelMaxOut(const int nthreads, const T* input_data,
|
|
|
|
|
T* output_data, const int channels,
|
|
|
|
|
const int input_height, const int input_width,
|
|
|
|
|
int groups, MaxOutProcess maxout_process) {
|
|
|
|
|
int size = input_height * input_width * channels / groups;
|
|
|
|
|
int featLen = input_height * input_width;
|
|
|
|
|
const int size = input_height * input_width * channels / groups;
|
|
|
|
|
const int feat_len = input_height * input_width;
|
|
|
|
|
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
|
|
|
|
|
index += blockDim.x * gridDim.x) {
|
|
|
|
|
int batch_idx = index / size;
|
|
|
|
|
int i = index % size;
|
|
|
|
|
int channel_idx = i / featLen;
|
|
|
|
|
int feat_idx = i % featLen;
|
|
|
|
|
int batch_offset = index % size;
|
|
|
|
|
int channel_idx = batch_offset / feat_len;
|
|
|
|
|
int feat_idx = batch_offset % feat_len;
|
|
|
|
|
int data_idx =
|
|
|
|
|
(batch_idx * size + channel_idx * featLen) * groups + feat_idx;
|
|
|
|
|
(batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
|
|
|
|
|
T ele = maxout_process.initial();
|
|
|
|
|
for (int g = 0; g < groups; g++) {
|
|
|
|
|
maxout_process.compute(ele, input_data[data_idx + g * featLen]);
|
|
|
|
|
for (int g = 0; g < groups; ++g) {
|
|
|
|
|
maxout_process.compute(ele, input_data[data_idx + g * feat_len]);
|
|
|
|
|
}
|
|
|
|
|
maxout_process.finalize(ele, (static_cast<T>(groups)));
|
|
|
|
|
output_data[index] = ele;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -47,21 +46,21 @@ __global__ void KernelMaxoutGrad(
|
|
|
|
|
const int nthreads, const T* input_data, const T* output_data,
|
|
|
|
|
const T* output_grad, T* input_grad, const int channels,
|
|
|
|
|
const int input_height, const int input_width, int groups) {
|
|
|
|
|
int size = input_height * input_width * channels / groups;
|
|
|
|
|
int featLen = input_height * input_width;
|
|
|
|
|
const int size = input_height * input_width * channels / groups;
|
|
|
|
|
const int feat_len = input_height * input_width;
|
|
|
|
|
for (int index = blockIdx.x * blockDim.x + threadIdx.x; index < nthreads;
|
|
|
|
|
index += blockDim.x * gridDim.x) {
|
|
|
|
|
int batch_idx = index / size;
|
|
|
|
|
int i = index % size;
|
|
|
|
|
int channel_idx = i / featLen;
|
|
|
|
|
int feat_idx = i % featLen;
|
|
|
|
|
int batch_offset = index % size;
|
|
|
|
|
int channel_idx = batch_offset / feat_len;
|
|
|
|
|
int feat_idx = batch_offset % feat_len;
|
|
|
|
|
int data_idx =
|
|
|
|
|
(batch_idx * size + channel_idx * featLen) * groups + feat_idx;
|
|
|
|
|
(batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
|
|
|
|
|
int maxIndex = -1;
|
|
|
|
|
bool stop = false;
|
|
|
|
|
for (int g = 0; g < groups && !stop; g++) {
|
|
|
|
|
if (input_data[data_idx + g * featLen] == output_data[index]) {
|
|
|
|
|
maxIndex = data_idx + g * featLen;
|
|
|
|
|
if (input_data[data_idx + g * feat_len] == output_data[index]) {
|
|
|
|
|
maxIndex = data_idx + g * feat_len;
|
|
|
|
|
stop = true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -73,28 +72,25 @@ __global__ void KernelMaxoutGrad(
|
|
|
|
|
}
|
|
|
|
|
/*
|
|
|
|
|
* All tensors are in NCHW format.
|
|
|
|
|
* Ksize, strides, paddings are two elements. These two elements represent
|
|
|
|
|
* height and width, respectively.
|
|
|
|
|
*/
|
|
|
|
|
template <typename MaxOutProcess, typename T>
|
|
|
|
|
class MaxOutFunctor<platform::GPUPlace, MaxOutProcess, T> {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(const platform::DeviceContext& context,
|
|
|
|
|
const framework::Tensor& input, framework::Tensor& output,
|
|
|
|
|
int groups, int num_channels,
|
|
|
|
|
const framework::Tensor& input, framework::Tensor * output,
|
|
|
|
|
int groups,
|
|
|
|
|
MaxOutProcess maxout_process) {
|
|
|
|
|
const int batch_size = input.dims()[0];
|
|
|
|
|
const int input_channels = input.dims()[1];
|
|
|
|
|
const int input_height = input.dims()[2];
|
|
|
|
|
const int input_width = input.dims()[3];
|
|
|
|
|
const int output_channels = num_channels / groups;
|
|
|
|
|
const int output_height = output.dims()[2];
|
|
|
|
|
const int output_width = output.dims()[3];
|
|
|
|
|
const int output_channels = output->dims()[1];
|
|
|
|
|
const int output_height = output->dims()[2];
|
|
|
|
|
const int output_width = output->dims()[3];
|
|
|
|
|
|
|
|
|
|
const T* input_data = input.data<T>();
|
|
|
|
|
T* output_data = output.mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
int nthreads = batch_size * output_channels * output_height * output_width;
|
|
|
|
|
T* output_data = output->mutable_data<T>(context.GetPlace());
|
|
|
|
|
int nthreads = output->numel();
|
|
|
|
|
int blocks = (nthreads + 1024 - 1) / 1024;
|
|
|
|
|
dim3 threads(1024, 1);
|
|
|
|
|
dim3 grid(blocks, 1);
|
|
|
|
@ -110,8 +106,6 @@ class MaxOutFunctor<platform::GPUPlace, MaxOutProcess, T> {
|
|
|
|
|
};
|
|
|
|
|
/*
|
|
|
|
|
* All tensors are in NCHW format.
|
|
|
|
|
* Ksize, strides, paddings are two elements. These two elements represent
|
|
|
|
|
* height and width, respectively.
|
|
|
|
|
*/
|
|
|
|
|
template <typename T>
|
|
|
|
|
class MaxOutGradFunctor<platform::GPUPlace, T> {
|
|
|
|
@ -120,7 +114,7 @@ class MaxOutGradFunctor<platform::GPUPlace, T> {
|
|
|
|
|
const framework::Tensor& input, framework::Tensor& input_grad,
|
|
|
|
|
const framework::Tensor& output,
|
|
|
|
|
const framework::Tensor& output_grad,
|
|
|
|
|
int groups, int num_channels) {
|
|
|
|
|
int groups) {
|
|
|
|
|
const int batch_size = input.dims()[0];
|
|
|
|
|
const int input_channels = input.dims()[1];
|
|
|
|
|
const int input_height = input.dims()[2];
|
|
|
|
@ -133,8 +127,7 @@ class MaxOutGradFunctor<platform::GPUPlace, T> {
|
|
|
|
|
const T* output_data = output.data<T>();
|
|
|
|
|
const T* output_grad_data = output_grad.data<T>();
|
|
|
|
|
T* input_grad_data = input_grad.mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
int nthreads = batch_size * output_channels * output_height * output_width;
|
|
|
|
|
int nthreads = output.numel();
|
|
|
|
|
int blocks = (nthreads + 1024 - 1) / 1024;
|
|
|
|
|
dim3 threads(1024, 1);
|
|
|
|
|
dim3 grid(blocks, 1);
|
|
|
|
@ -152,9 +145,9 @@ template class MaxOutGradFunctor<platform::GPUPlace, float>;
|
|
|
|
|
template class MaxOutGradFunctor<platform::GPUPlace, double>;
|
|
|
|
|
|
|
|
|
|
template class MaxOutFunctor<platform::GPUPlace,
|
|
|
|
|
paddle::operators::math::MaxOut<float>, float>;
|
|
|
|
|
math::MaxOut<float>, float>;
|
|
|
|
|
template class MaxOutFunctor<platform::GPUPlace,
|
|
|
|
|
paddle::operators::math::MaxOut<double>, double>;
|
|
|
|
|
math::MaxOut<double>, double>;
|
|
|
|
|
|
|
|
|
|
} // namespace math
|
|
|
|
|
} // namespace operators
|
|
|
|
|