|
|
@ -22,8 +22,8 @@ namespace math {
|
|
|
|
template <typename T>
|
|
|
|
template <typename T>
|
|
|
|
__global__ void KernelMaxOut(const int nthreads, const T* input_data,
|
|
|
|
__global__ void KernelMaxOut(const int nthreads, const T* input_data,
|
|
|
|
const int channels, const int input_height,
|
|
|
|
const int channels, const int input_height,
|
|
|
|
const int input_width, int groups,
|
|
|
|
const int input_width, const int groups,
|
|
|
|
T* output_data) {
|
|
|
|
const int axis, T* output_data) {
|
|
|
|
const int size = input_height * input_width * channels / groups;
|
|
|
|
const int size = input_height * input_width * channels / groups;
|
|
|
|
const int feat_len = input_height * input_width;
|
|
|
|
const int feat_len = input_height * input_width;
|
|
|
|
int index = blockIdx.x * blockDim.x + threadIdx.x;
|
|
|
|
int index = blockIdx.x * blockDim.x + threadIdx.x;
|
|
|
@ -31,13 +31,22 @@ __global__ void KernelMaxOut(const int nthreads, const T* input_data,
|
|
|
|
for (int i = index; i < nthreads; i += offset) {
|
|
|
|
for (int i = index; i < nthreads; i += offset) {
|
|
|
|
int batch_idx = i / size;
|
|
|
|
int batch_idx = i / size;
|
|
|
|
int batch_offset = i % size;
|
|
|
|
int batch_offset = i % size;
|
|
|
|
int channel_idx = batch_offset / feat_len;
|
|
|
|
int channel_idx, feat_idx, data_idx;
|
|
|
|
int feat_idx = batch_offset % feat_len;
|
|
|
|
if (axis == 1) {
|
|
|
|
int data_idx =
|
|
|
|
channel_idx = batch_offset / feat_len;
|
|
|
|
(batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
|
|
|
|
feat_idx = batch_offset % feat_len;
|
|
|
|
|
|
|
|
data_idx =
|
|
|
|
|
|
|
|
(batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
channel_idx = batch_offset % channels;
|
|
|
|
|
|
|
|
feat_idx = batch_offset / channels;
|
|
|
|
|
|
|
|
data_idx =
|
|
|
|
|
|
|
|
(batch_idx * size + feat_idx * channels + channel_idx) * groups;
|
|
|
|
|
|
|
|
}
|
|
|
|
T ele = static_cast<T>(-FLT_MAX);
|
|
|
|
T ele = static_cast<T>(-FLT_MAX);
|
|
|
|
for (int g = 0; g < groups; ++g) {
|
|
|
|
for (int g = 0; g < groups; ++g) {
|
|
|
|
T x = input_data[data_idx + g * feat_len];
|
|
|
|
int idx_offset = (axis == 1 ? g * feat_len : g);
|
|
|
|
|
|
|
|
T x = input_data[data_idx + idx_offset];
|
|
|
|
ele = ele > x ? ele : x;
|
|
|
|
ele = ele > x ? ele : x;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
output_data[i] = ele;
|
|
|
|
output_data[i] = ele;
|
|
|
@ -48,7 +57,7 @@ __global__ void KernelMaxoutGrad(const int nthreads, const T* input_data,
|
|
|
|
const T* output_data, const T* output_grad,
|
|
|
|
const T* output_data, const T* output_grad,
|
|
|
|
T* input_grad, const int channels,
|
|
|
|
T* input_grad, const int channels,
|
|
|
|
const int input_height, const int input_width,
|
|
|
|
const int input_height, const int input_width,
|
|
|
|
int groups) {
|
|
|
|
const int groups, const int axis) {
|
|
|
|
const int size = input_height * input_width * channels / groups;
|
|
|
|
const int size = input_height * input_width * channels / groups;
|
|
|
|
const int feat_len = input_height * input_width;
|
|
|
|
const int feat_len = input_height * input_width;
|
|
|
|
int index = blockIdx.x * blockDim.x + threadIdx.x;
|
|
|
|
int index = blockIdx.x * blockDim.x + threadIdx.x;
|
|
|
@ -56,15 +65,24 @@ __global__ void KernelMaxoutGrad(const int nthreads, const T* input_data,
|
|
|
|
for (int i = index; i < nthreads; i += offset) {
|
|
|
|
for (int i = index; i < nthreads; i += offset) {
|
|
|
|
int batch_idx = i / size;
|
|
|
|
int batch_idx = i / size;
|
|
|
|
int batch_offset = i % size;
|
|
|
|
int batch_offset = i % size;
|
|
|
|
int channel_idx = batch_offset / feat_len;
|
|
|
|
int channel_idx, feat_idx, data_idx;
|
|
|
|
int feat_idx = batch_offset % feat_len;
|
|
|
|
if (axis == 1) {
|
|
|
|
int data_idx =
|
|
|
|
channel_idx = batch_offset / feat_len;
|
|
|
|
(batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
|
|
|
|
feat_idx = batch_offset % feat_len;
|
|
|
|
|
|
|
|
data_idx =
|
|
|
|
|
|
|
|
(batch_idx * size + channel_idx * feat_len) * groups + feat_idx;
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
channel_idx = batch_offset % channels;
|
|
|
|
|
|
|
|
feat_idx = batch_offset / channels;
|
|
|
|
|
|
|
|
data_idx =
|
|
|
|
|
|
|
|
(batch_idx * size + feat_idx * channels + channel_idx) * groups;
|
|
|
|
|
|
|
|
}
|
|
|
|
int max_index = -1;
|
|
|
|
int max_index = -1;
|
|
|
|
bool continue_match = true;
|
|
|
|
bool continue_match = true;
|
|
|
|
for (int g = 0; g < groups && continue_match; ++g) {
|
|
|
|
for (int g = 0; g < groups && continue_match; ++g) {
|
|
|
|
if (input_data[data_idx + g * feat_len] == output_data[i]) {
|
|
|
|
int idx_offset = (axis == 1 ? g * feat_len : g);
|
|
|
|
max_index = data_idx + g * feat_len;
|
|
|
|
if (input_data[data_idx + idx_offset] == output_data[i]) {
|
|
|
|
|
|
|
|
max_index = data_idx + idx_offset;
|
|
|
|
continue_match = false;
|
|
|
|
continue_match = false;
|
|
|
|
break;
|
|
|
|
break;
|
|
|
|
}
|
|
|
|
}
|
|
|
@ -75,21 +93,19 @@ __global__ void KernelMaxoutGrad(const int nthreads, const T* input_data,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
/*
|
|
|
|
/*
|
|
|
|
* All tensors are in NCHW format.
|
|
|
|
* All tensors are in NCHW or NHWC format.
|
|
|
|
*/
|
|
|
|
*/
|
|
|
|
template <typename T>
|
|
|
|
template <typename T>
|
|
|
|
class MaxOutFunctor<platform::CUDADeviceContext, T> {
|
|
|
|
class MaxOutFunctor<platform::CUDADeviceContext, T> {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
void operator()(const platform::CUDADeviceContext& context,
|
|
|
|
void operator()(const platform::CUDADeviceContext& context,
|
|
|
|
const framework::Tensor& input, framework::Tensor* output,
|
|
|
|
const framework::Tensor& input, framework::Tensor* output,
|
|
|
|
int groups) {
|
|
|
|
const int groups, const int axis) {
|
|
|
|
const int batch_size = input.dims()[0];
|
|
|
|
const int batch_size = input.dims()[0];
|
|
|
|
const int input_channels = input.dims()[1];
|
|
|
|
const int input_channels = input.dims()[axis];
|
|
|
|
const int input_height = input.dims()[2];
|
|
|
|
const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]);
|
|
|
|
const int input_width = input.dims()[3];
|
|
|
|
const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]);
|
|
|
|
const int output_channels = output->dims()[1];
|
|
|
|
const int output_channels = output->dims()[axis];
|
|
|
|
const int output_height = output->dims()[2];
|
|
|
|
|
|
|
|
const int output_width = output->dims()[3];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const T* input_data = input.data<T>();
|
|
|
|
const T* input_data = input.data<T>();
|
|
|
|
T* output_data = output->mutable_data<T>(context.GetPlace());
|
|
|
|
T* output_data = output->mutable_data<T>(context.GetPlace());
|
|
|
@ -100,11 +116,11 @@ class MaxOutFunctor<platform::CUDADeviceContext, T> {
|
|
|
|
|
|
|
|
|
|
|
|
KernelMaxOut<T><<<grid, threads, 0, context.stream()>>>(
|
|
|
|
KernelMaxOut<T><<<grid, threads, 0, context.stream()>>>(
|
|
|
|
nthreads, input_data, input_channels, input_height, input_width, groups,
|
|
|
|
nthreads, input_data, input_channels, input_height, input_width, groups,
|
|
|
|
output_data);
|
|
|
|
axis, output_data);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
/*
|
|
|
|
/*
|
|
|
|
* All tensors are in NCHW format.
|
|
|
|
* All tensors are in NCHW or NHWC format.
|
|
|
|
*/
|
|
|
|
*/
|
|
|
|
template <typename T>
|
|
|
|
template <typename T>
|
|
|
|
class MaxOutGradFunctor<platform::CUDADeviceContext, T> {
|
|
|
|
class MaxOutGradFunctor<platform::CUDADeviceContext, T> {
|
|
|
@ -112,14 +128,13 @@ class MaxOutGradFunctor<platform::CUDADeviceContext, T> {
|
|
|
|
void operator()(const platform::CUDADeviceContext& context,
|
|
|
|
void operator()(const platform::CUDADeviceContext& context,
|
|
|
|
const framework::Tensor& input, framework::Tensor* input_grad,
|
|
|
|
const framework::Tensor& input, framework::Tensor* input_grad,
|
|
|
|
const framework::Tensor& output,
|
|
|
|
const framework::Tensor& output,
|
|
|
|
const framework::Tensor& output_grad, int groups) {
|
|
|
|
const framework::Tensor& output_grad, const int groups,
|
|
|
|
|
|
|
|
const int axis) {
|
|
|
|
const int batch_size = input.dims()[0];
|
|
|
|
const int batch_size = input.dims()[0];
|
|
|
|
const int input_channels = input.dims()[1];
|
|
|
|
const int input_channels = input.dims()[axis];
|
|
|
|
const int input_height = input.dims()[2];
|
|
|
|
const int input_height = (axis == 1 ? input.dims()[2] : input.dims()[1]);
|
|
|
|
const int input_width = input.dims()[3];
|
|
|
|
const int input_width = (axis == 1 ? input.dims()[3] : input.dims()[2]);
|
|
|
|
const int output_channels = output.dims()[1];
|
|
|
|
const int output_channels = output.dims()[axis];
|
|
|
|
const int output_height = output.dims()[2];
|
|
|
|
|
|
|
|
const int output_width = output.dims()[3];
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const T* input_data = input.data<T>();
|
|
|
|
const T* input_data = input.data<T>();
|
|
|
|
const T* output_data = output.data<T>();
|
|
|
|
const T* output_data = output.data<T>();
|
|
|
@ -132,7 +147,7 @@ class MaxOutGradFunctor<platform::CUDADeviceContext, T> {
|
|
|
|
|
|
|
|
|
|
|
|
KernelMaxoutGrad<T><<<grid, threads, 0, context.stream()>>>(
|
|
|
|
KernelMaxoutGrad<T><<<grid, threads, 0, context.stream()>>>(
|
|
|
|
nthreads, input_data, output_data, output_grad_data, input_grad_data,
|
|
|
|
nthreads, input_data, output_data, output_grad_data, input_grad_data,
|
|
|
|
input_channels, input_height, input_width, groups);
|
|
|
|
input_channels, input_height, input_width, groups, axis);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|