Fix (according to the review)

tonyyang-svail-feed-op-desgin
chengduoZH 8 years ago
parent f6e69d7412
commit dfc8d3c1c1

@ -7,5 +7,3 @@ endif()
nv_test(math_function_test SRCS math_function_test.cc DEPS math_function tensor)
cc_test(im2col_test SRCS im2col_test.cc DEPS math_function tensor)
cc_test(pool_test_maxPool2d_test SRCS pool_test_maxPool2d.cc DEPS math_function tensor)
cc_test(pool_test_maxPool3d_test SRCS pool_test_maxPool3d.cc DEPS math_function tensor)

@ -1,154 +0,0 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/operators/math/pooling.h"
#include "paddle/memory/memcpy.h"
#include "paddle/platform/enforce.h"
#include <stdlib.h>
#include <time.h>
#ifndef PADDLE_ONLY_CPU
template <typename PoolType, typename PoolGradType>
void testPool2d(paddle::platform::DeviceContext& context, PoolType pool_process,
PoolGradType poolGrad_process, paddle::framework::Tensor& input,
paddle::framework::Tensor& input_grad,
paddle::framework::Tensor& output,
paddle::framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings) {
paddle::operators::math::Pool2dForwardFunctor<paddle::platform::GPUPlace,
PoolType, float>
pool2d_forward;
pool2d_forward(context, input, output, ksize, strides, paddings,
pool_process);
int times = 50;
clock_t start, finish;
double totaltime;
// Pool2dBackwardFunctor
start = clock();
for (int i = 0; i < times; ++i) {
paddle::operators::math::Pool2dBackwardFunctor<paddle::platform::GPUPlace,
PoolGradType, float>
pool2d_backward;
pool2d_backward(context, input, input_grad, output, output_grad, ksize,
strides, paddings, poolGrad_process);
PADDLE_ENFORCE(cudaStreamSynchronize(0),
"cudaStreamSynchronize failed in pool2d_backward CopyFrom");
}
finish = clock();
totaltime = (double)(finish - start) / CLOCKS_PER_SEC;
totaltime /= times;
std::cout << "\nPool3dBackwardFunctor: " << totaltime << "s" << std::endl;
// MaxPool3dBackwardFunctor
start = clock();
for (int j = 0; j < times; ++j) {
paddle::operators::math::MaxPool2dBackwardFunctor<
paddle::platform::GPUPlace, float>
maxpool2d_backward;
maxpool2d_backward(context, input, input_grad, output, output_grad, ksize,
strides, paddings);
PADDLE_ENFORCE(
cudaStreamSynchronize(0),
"cudaStreamSynchronize failed in maxpool2d_backward CopyFrom");
}
finish = clock();
totaltime = (double)(finish - start) / CLOCKS_PER_SEC;
totaltime /= times;
std::cout << "\nMaxPool3dBackwardFunctor: " << totaltime << "s" << std::endl;
}
void test2dPool() {
using paddle::platform::DeviceContext;
using paddle::platform::CUDADeviceContext;
using paddle::platform::GPUPlace;
paddle::framework::Tensor input_tmp;
paddle::framework::Tensor output_tmp;
paddle::framework::Tensor input;
paddle::framework::Tensor input_grad;
paddle::framework::Tensor output;
paddle::framework::Tensor output_grad;
int batch = 32;
int channel = 32;
int input_height = 128;
int input_width = 128;
int in_len = batch * channel * input_height * input_width;
std::vector<int> ksize({3, 3});
std::vector<int> strides({1, 1});
std::vector<int> paddings({0, 0});
int output_height =
(input_height - ksize[0] + 2 * paddings[0]) / strides[0] + 1;
int output_width =
(input_width - ksize[1] + 2 * paddings[1]) / strides[1] + 1;
int output_len = output_height * output_width;
input_tmp.mutable_data<float>({batch, channel, input_height, input_width},
paddle::platform::CPUPlace());
output_tmp.mutable_data<float>({batch, channel, output_height, output_width},
paddle::platform::CPUPlace());
float* arr = new float[in_len];
auto* place = new paddle::platform::GPUPlace();
float* input_ptr = input_tmp.data<float>();
for (int i = 0; i < in_len; ++i) arr[i] = i; // rand() / double(RAND_MAX/2);
memcpy(input_ptr, arr, in_len * sizeof(float));
input.CopyFrom<float>(input_tmp, *place);
input_ptr = input_tmp.data<float>();
for (int i = 0; i < in_len; ++i) arr[i] = 0;
memcpy(input_ptr, arr, in_len * sizeof(float));
input_grad.CopyFrom<float>(input_tmp, *place);
// output
input_ptr = output_tmp.data<float>();
for (int i = 0; i < output_len; ++i)
arr[i] = 0; // rand() / double(RAND_MAX/2);
memcpy(input_ptr, arr, output_len * sizeof(float));
output.CopyFrom<float>(output_tmp, *place);
// output_grad
input_ptr = output_tmp.data<float>();
for (int i = 0; i < output_len; ++i)
arr[i] = 1; // rand() / double(RAND_MAX/2);
memcpy(input_ptr, arr, output_len * sizeof(float));
output_grad.CopyFrom<float>(output_tmp, *place);
paddle::platform::DeviceContext* context =
new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace());
paddle::operators::math::pool::maxPool<float> pool_process;
paddle::operators::math::pool::maxPoolGrad<float> poolGrad_process;
testPool2d<paddle::operators::math::pool::maxPool<float>,
paddle::operators::math::pool::maxPoolGrad<float>>(
*context, pool_process, poolGrad_process, input, input_grad, output,
output_grad, ksize, strides, paddings);
}
int main() {
// testPool3d<paddle::platform::CPUPlace>();
test2dPool();
// testPool3d<paddle::platform::GPUPlace>();
}
#endif
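Side note on the deleted benchmarks above: they time GPU work with clock() plus a cudaStreamSynchronize() per iteration. A more direct way to time device work is CUDA events; a minimal sketch (not part of this commit, standard CUDA runtime API only):

cudaEvent_t start_ev, stop_ev;
cudaEventCreate(&start_ev);
cudaEventCreate(&stop_ev);
cudaEventRecord(start_ev, 0);  // record on the default stream
// ... run the pool2d_backward iterations here ...
cudaEventRecord(stop_ev, 0);
cudaEventSynchronize(stop_ev);
float elapsed_ms = 0.0f;
cudaEventElapsedTime(&elapsed_ms, start_ev, stop_ev);  // GPU elapsed time in ms
cudaEventDestroy(start_ev);
cudaEventDestroy(stop_ev);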

@ -1,157 +0,0 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#include <gtest/gtest.h>
#include "paddle/operators/math/pooling.h"
#include "paddle/memory/memcpy.h"
#include "paddle/platform/enforce.h"
#include <stdlib.h>
#include <time.h>
#ifndef PADDLE_ONLY_CPU
template <typename PoolType, typename PoolGradType>
void testPool3d(paddle::platform::DeviceContext& context, PoolType pool_process,
PoolGradType poolGrad_process, paddle::framework::Tensor& input,
paddle::framework::Tensor& input_grad,
paddle::framework::Tensor& output,
paddle::framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings) {
paddle::operators::math::Pool3dForwardFunctor<paddle::platform::GPUPlace,
PoolType, float>
pool3d_forward;
pool3d_forward(context, input, output, ksize, strides, paddings,
pool_process);
int times = 50;
clock_t start, finish;
double totaltime;
// Pool3dBackwardFunctor
start = clock();
for (int i = 0; i < times; ++i) {
paddle::operators::math::Pool3dBackwardFunctor<paddle::platform::GPUPlace,
PoolGradType, float>
pool3d_backward;
pool3d_backward(context, input, input_grad, output, output_grad, ksize,
strides, paddings, poolGrad_process);
PADDLE_ENFORCE(cudaStreamSynchronize(0),
"cudaStreamSynchronize failed in pool3d_backward CopyFrom");
}
finish = clock();
totaltime = (double)(finish - start) / CLOCKS_PER_SEC;
totaltime /= times;
std::cout << "\nPool3dBackwardFunctor: " << totaltime << "s" << std::endl;
// MaxPool3dBackwardFunctor
start = clock();
for (int j = 0; j < times; ++j) {
paddle::operators::math::MaxPool3dBackwardFunctor<
paddle::platform::GPUPlace, float>
maxpool3d_backward;
maxpool3d_backward(context, input, input_grad, output, output_grad, ksize,
strides, paddings);
PADDLE_ENFORCE(
cudaStreamSynchronize(0),
"cudaStreamSynchronize failed in maxpool3d_backward CopyFrom");
}
finish = clock();
totaltime = (double)(finish - start) / CLOCKS_PER_SEC;
totaltime /= times;
std::cout << "\nMaxPool3dBackwardFunctor: " << totaltime << "s" << std::endl;
}
void test3dPool() {
using paddle::platform::DeviceContext;
using paddle::platform::CUDADeviceContext;
using paddle::platform::GPUPlace;
paddle::framework::Tensor input_tmp;
paddle::framework::Tensor output_tmp;
paddle::framework::Tensor input;
paddle::framework::Tensor input_grad;
paddle::framework::Tensor output;
paddle::framework::Tensor output_grad;
int batch = 32;
int channel = 4;
int input_depth = 4;
int input_height = 128;
int input_width = 128;
int in_len = batch * channel * input_depth * input_height * input_width;
std::vector<int> ksize({3, 3, 3});
std::vector<int> strides({2, 2, 2});
std::vector<int> paddings({1, 1, 1});
int output_depth =
(input_depth - ksize[0] + 2 * paddings[0]) / strides[0] + 1;
int output_height =
(input_height - ksize[1] + 2 * paddings[1]) / strides[1] + 1;
int output_width =
(input_width - ksize[2] + 2 * paddings[2]) / strides[2] + 1;
int output_len = output_depth * output_height * output_width;
input_tmp.mutable_data<float>(
{batch, channel, input_depth, input_height, input_width},
paddle::platform::CPUPlace());
output_tmp.mutable_data<float>(
{batch, channel, output_depth, output_height, output_width},
paddle::platform::CPUPlace());
float* arr = new float[in_len];
auto* place = new paddle::platform::GPUPlace();
// input
float* input_ptr = input_tmp.data<float>();
for (int i = 0; i < in_len; ++i) arr[i] = i; // rand() / double(RAND_MAX/2);
memcpy(input_ptr, arr, in_len * sizeof(float));
input.CopyFrom<float>(input_tmp, *place);
// input_grad
input_ptr = input_tmp.data<float>();
for (int i = 0; i < in_len; ++i) arr[i] = 0;
memcpy(input_ptr, arr, in_len * sizeof(float));
input_grad.CopyFrom<float>(input_tmp, *place);
// output
input_ptr = output_tmp.data<float>();
for (int i = 0; i < output_len; ++i)
arr[i] = 0; // rand() / double(RAND_MAX/2);
memcpy(input_ptr, arr, output_len * sizeof(float));
output.CopyFrom<float>(output_tmp, *place);
// output_grad
input_ptr = output_tmp.data<float>();
for (int i = 0; i < output_len; ++i)
arr[i] = 1; // rand() / double(RAND_MAX/2);
memcpy(input_ptr, arr, output_len * sizeof(float));
output_grad.CopyFrom<float>(output_tmp, *place);
paddle::platform::DeviceContext* context =
new paddle::platform::CUDADeviceContext(paddle::platform::GPUPlace());
paddle::operators::math::pool::maxPool<float> pool_process;
paddle::operators::math::pool::maxPoolGrad<float> poolGrad_process;
testPool3d<paddle::operators::math::pool::maxPool<float>,
paddle::operators::math::pool::maxPoolGrad<float>>(
*context, pool_process, poolGrad_process, input, input_grad, output,
output_grad, ksize, strides, paddings);
}
int main() { test3dPool(); }
#endif

@ -19,12 +19,12 @@ namespace operators {
namespace math {
template <typename PoolProcess, typename T>
class Pool2dForwardFunctor<platform::CPUPlace, PoolProcess, T> {
class Pool2dFunctor<platform::CPUPlace, PoolProcess, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output,
std::vector<int>& ksize, std::vector<int>& strides,
std::vector<int>& paddings, PoolProcess pool_process) {
std::vector<int>& paddings, PoolProcess pool_compute) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
@ -54,14 +54,14 @@ class Pool2dForwardFunctor<platform::CPUPlace, PoolProcess, T> {
int wstart = pw * stride_width - padding_width;
int wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0);
T ele = pool_process.initial();
T ele = pool_compute.initial();
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
pool_process.process(ele, input_data[h * input_width + w]);
pool_compute.compute(ele, input_data[h * input_width + w]);
}
}
int pool_size = (hend - hstart) * (wend - wstart);
pool_process.finalize(ele, (static_cast<T>(pool_size)));
pool_compute.finalize(ele, (static_cast<T>(pool_size)));
output_data[ph * output_width + pw] = ele;
}
}
@ -73,14 +73,14 @@ class Pool2dForwardFunctor<platform::CPUPlace, PoolProcess, T> {
};
template <typename PoolProcess, class T>
class Pool2dBackwardFunctor<platform::CPUPlace, PoolProcess, T> {
class Pool2dGradFunctor<platform::CPUPlace, PoolProcess, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings,
PoolProcess pool_process) {
PoolProcess pool_compute) {
const int batch_size = input.dims()[0];
const int input_height = input.dims()[2];
const int input_width = input.dims()[3];
@ -115,12 +115,11 @@ class Pool2dBackwardFunctor<platform::CPUPlace, PoolProcess, T> {
float scale = 1.0 / pool_size;
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
pool_process.gradProcess(
input_data[h * input_width + w],
output_data[ph * output_width + pw],
output_grad_data[ph * output_width + pw],
input_grad_data[h * input_width + w],
static_cast<T>(scale));
pool_compute.compute(input_data[h * input_width + w],
output_data[ph * output_width + pw],
output_grad_data[ph * output_width + pw],
input_grad_data[h * input_width + w],
static_cast<T>(scale));
}
}
}
@ -135,7 +134,7 @@ class Pool2dBackwardFunctor<platform::CPUPlace, PoolProcess, T> {
};
template <class T>
class MaxPool2dBackwardFunctor<platform::CPUPlace, T> {
class MaxPool2dGradFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& input_grad,
@ -195,37 +194,33 @@ class MaxPool2dBackwardFunctor<platform::CPUPlace, T> {
}
};
template class MaxPool2dBackwardFunctor<platform::CPUPlace, float>;
// template class MaxPool2dBackwardFunctor<platform::CPUPlace, double>;
template class Pool2dForwardFunctor<
platform::CPUPlace, paddle::operators::math::pool::maxPool<float>, float>;
template class Pool2dForwardFunctor<
platform::CPUPlace, paddle::operators::math::pool::avgPool<float>, float>;
template class Pool2dBackwardFunctor<
platform::CPUPlace, paddle::operators::math::pool::maxPoolGrad<float>,
float>;
template class Pool2dBackwardFunctor<
platform::CPUPlace, paddle::operators::math::pool::avgPoolGrad<float>,
float>;
template class Pool2dForwardFunctor<
platform::CPUPlace, paddle::operators::math::pool::maxPool<double>, double>;
template class Pool2dForwardFunctor<
platform::CPUPlace, paddle::operators::math::pool::avgPool<double>, double>;
template class Pool2dBackwardFunctor<
platform::CPUPlace, paddle::operators::math::pool::maxPoolGrad<double>,
double>;
template class Pool2dBackwardFunctor<
platform::CPUPlace, paddle::operators::math::pool::avgPoolGrad<double>,
double>;
template class MaxPool2dGradFunctor<platform::CPUPlace, float>;
// template class MaxPool2dGradFunctor<platform::CPUPlace, double>;
template class Pool2dFunctor<platform::CPUPlace,
paddle::operators::math::maxPool<float>, float>;
template class Pool2dFunctor<platform::CPUPlace,
paddle::operators::math::avgPool<float>, float>;
template class Pool2dGradFunctor<
platform::CPUPlace, paddle::operators::math::maxPoolGrad<float>, float>;
template class Pool2dGradFunctor<
platform::CPUPlace, paddle::operators::math::avgPoolGrad<float>, float>;
template class Pool2dFunctor<platform::CPUPlace,
paddle::operators::math::maxPool<double>, double>;
template class Pool2dFunctor<platform::CPUPlace,
paddle::operators::math::avgPool<double>, double>;
template class Pool2dGradFunctor<
platform::CPUPlace, paddle::operators::math::maxPoolGrad<double>, double>;
template class Pool2dGradFunctor<
platform::CPUPlace, paddle::operators::math::avgPoolGrad<double>, double>;
template <typename PoolProcess, class T>
class Pool3dForwardFunctor<platform::CPUPlace, PoolProcess, T> {
class Pool3dFunctor<platform::CPUPlace, PoolProcess, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output,
std::vector<int>& ksize, std::vector<int>& strides,
std::vector<int>& paddings, PoolProcess pool_process) {
std::vector<int>& paddings, PoolProcess pool_compute) {
const int batch_size = input.dims()[0];
const int input_depth = input.dims()[2];
const int input_height = input.dims()[3];
@ -265,11 +260,11 @@ class Pool3dForwardFunctor<platform::CPUPlace, PoolProcess, T> {
int wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0);
int output_idx = (pd * output_height + ph) * output_width + pw;
T ele = pool_process.initial();
T ele = pool_compute.initial();
for (int d = dstart; d < dend; ++d) {
for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) {
pool_process.process(
pool_compute.compute(
ele,
input_data[(d * input_height + h) * input_width + w]);
}
@ -277,7 +272,7 @@ class Pool3dForwardFunctor<platform::CPUPlace, PoolProcess, T> {
}
int pool_size =
(dend - dstart) * (hend - hstart) * (wend - wstart);
pool_process.finalize(ele, static_cast<T>(pool_size));
pool_compute.finalize(ele, static_cast<T>(pool_size));
output_data[output_idx] = ele;
}
}
@ -290,14 +285,14 @@ class Pool3dForwardFunctor<platform::CPUPlace, PoolProcess, T> {
};
template <typename PoolProcess, class T>
class Pool3dBackwardFunctor<platform::CPUPlace, PoolProcess, T> {
class Pool3dGradFunctor<platform::CPUPlace, PoolProcess, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings,
PoolProcess pool_process) {
PoolProcess pool_compute) {
const int batch_size = input.dims()[0];
const int input_depth = input.dims()[2];
const int input_height = input.dims()[3];
@ -348,7 +343,7 @@ class Pool3dBackwardFunctor<platform::CPUPlace, PoolProcess, T> {
int input_idx = (d * input_height + h) * input_width + w;
int output_idx =
(pd * output_height + ph) * output_width + pw;
pool_process.gradProcess(
pool_compute.compute(
input_data[input_idx], output_data[output_idx],
output_grad_data[output_idx],
input_grad_data[input_idx], static_cast<T>(scale));
@ -368,7 +363,7 @@ class Pool3dBackwardFunctor<platform::CPUPlace, PoolProcess, T> {
};
template <class T>
class MaxPool3dBackwardFunctor<platform::CPUPlace, T> {
class MaxPool3dGradFunctor<platform::CPUPlace, T> {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& input_grad,
@ -442,29 +437,25 @@ class MaxPool3dBackwardFunctor<platform::CPUPlace, T> {
}
};
template class MaxPool3dBackwardFunctor<platform::CPUPlace, float>;
// template class MaxPool3dBackwardFunctor<platform::CPUPlace, double>;
template class Pool3dForwardFunctor<
platform::CPUPlace, paddle::operators::math::pool::maxPool<float>, float>;
template class Pool3dForwardFunctor<
platform::CPUPlace, paddle::operators::math::pool::avgPool<float>, float>;
template class Pool3dBackwardFunctor<
platform::CPUPlace, paddle::operators::math::pool::maxPoolGrad<float>,
float>;
template class Pool3dBackwardFunctor<
platform::CPUPlace, paddle::operators::math::pool::avgPoolGrad<float>,
float>;
template class Pool3dForwardFunctor<
platform::CPUPlace, paddle::operators::math::pool::maxPool<double>, double>;
template class Pool3dForwardFunctor<
platform::CPUPlace, paddle::operators::math::pool::avgPool<double>, double>;
template class Pool3dBackwardFunctor<
platform::CPUPlace, paddle::operators::math::pool::maxPoolGrad<double>,
double>;
template class Pool3dBackwardFunctor<
platform::CPUPlace, paddle::operators::math::pool::avgPoolGrad<double>,
double>;
template class MaxPool3dGradFunctor<platform::CPUPlace, float>;
// template class MaxPool3dGradFunctor<platform::CPUPlace, double>;
template class Pool3dFunctor<platform::CPUPlace,
paddle::operators::math::maxPool<float>, float>;
template class Pool3dFunctor<platform::CPUPlace,
paddle::operators::math::avgPool<float>, float>;
template class Pool3dGradFunctor<
platform::CPUPlace, paddle::operators::math::maxPoolGrad<float>, float>;
template class Pool3dGradFunctor<
platform::CPUPlace, paddle::operators::math::avgPoolGrad<float>, float>;
template class Pool3dFunctor<platform::CPUPlace,
paddle::operators::math::maxPool<double>, double>;
template class Pool3dFunctor<platform::CPUPlace,
paddle::operators::math::avgPool<double>, double>;
template class Pool3dGradFunctor<
platform::CPUPlace, paddle::operators::math::maxPoolGrad<double>, double>;
template class Pool3dGradFunctor<
platform::CPUPlace, paddle::operators::math::avgPoolGrad<double>, double>;
} // namespace math
} // namespace operators
} // namespace paddle
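For orientation, a minimal usage sketch of the renamed CPU functors (hedged: tensor allocation and context setup are assumed; only the names and call signatures come from this commit):

paddle::platform::CPUDeviceContext context;  // assumed default-constructible here
paddle::framework::Tensor input, output;     // assumed allocated as NCHW float
std::vector<int> ksize({3, 3}), strides({1, 1}), paddings({0, 0});
paddle::operators::math::Pool2dFunctor<
    paddle::platform::CPUPlace, paddle::operators::math::maxPool<float>, float>
    pool2d_forward;
paddle::operators::math::maxPool<float> pool_compute;
// Same call shape as before the rename; only the functor names and the
// PoolProcess method name (process -> compute) have changed.
pool2d_forward(context, input, output, ksize, strides, paddings, pool_compute);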

File diff suppressed because it is too large.

@ -21,17 +21,15 @@ limitations under the License. */
namespace paddle {
namespace operators {
namespace math {
//////////////////////
#define FLT_MAX __FLT_MAX__
/////////////////////
namespace pool {
template <class T>
class maxPool {
public:
DEVICE inline T initial() { return static_cast<T>(-FLT_MAX); }
DEVICE inline void process(T& y, const T& x) { y = y > x ? y : x; }
DEVICE inline void compute(T& y, const T& x) { y = y > x ? y : x; }
DEVICE inline void finalize(T& y, const T& pool_size) {}
};
@ -39,14 +37,14 @@ template <class T>
class avgPool {
public:
DEVICE inline T initial() { return static_cast<T>(0); }
DEVICE inline void process(T& y, const T& x) { y += x; }
DEVICE inline void compute(T& y, const T& x) { y += x; }
DEVICE inline void finalize(T& y, const T& pool_size) { y /= pool_size; }
};
template <class T>
class maxPoolGrad {
public:
DEVICE inline void gradProcess(const T& x, const T& y, const T& dy, T& dx,
T scale) {
DEVICE inline void compute(const T& x, const T& y, const T& dy, T& dx,
T scale) {
dx += dy * (x == y);
}
};
@ -54,35 +52,34 @@ class maxPoolGrad {
template <class T>
class avgPoolGrad {
public:
DEVICE inline void gradProcess(const T& x, const T& y, const T& dy, T& dx,
T scale) {
DEVICE inline void compute(const T& x, const T& y, const T& dy, T& dx,
T scale) {
dx += (scale * dy);
}
};
} // namespace pool
template <typename Place, typename PoolProcess, typename T>
class Pool2dForwardFunctor {
class Pool2dFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output,
std::vector<int>& ksize, std::vector<int>& strides,
std::vector<int>& paddings, PoolProcess pool_process);
std::vector<int>& paddings, PoolProcess pool_compute);
};
template <typename Place, typename PoolProcess, typename T>
class Pool2dBackwardFunctor {
class Pool2dGradFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings,
PoolProcess pool_process);
PoolProcess pool_compute);
};
template <typename Place, class T>
class MaxPool2dBackwardFunctor {
class MaxPool2dGradFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& input_grad,
@ -92,27 +89,27 @@ class MaxPool2dBackwardFunctor {
};
template <typename Place, typename PoolProcess, typename T>
class Pool3dForwardFunctor {
class Pool3dFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output,
std::vector<int>& ksize, std::vector<int>& strides,
std::vector<int>& paddings, PoolProcess pool_process);
std::vector<int>& paddings, PoolProcess pool_compute);
};
template <typename Place, typename PoolProcess, typename T>
class Pool3dBackwardFunctor {
class Pool3dGradFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& input_grad,
const framework::Tensor& output,
const framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings,
PoolProcess pool_process);
PoolProcess pool_compute);
};
template <typename Place, class T>
class MaxPool3dBackwardFunctor {
class MaxPool3dGradFunctor {
public:
void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& input_grad,
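Taken together, the PoolProcess contract used by the functors declared above is: initial() seeds the accumulator, compute() folds in one window element, and finalize() normalizes by the window size (a no-op for max pooling, a division for average pooling). An illustrative fold over one flat window (a sketch, not code from this commit):

template <typename T, typename PoolProcess>
T PoolOneWindow(const T* window, int pool_size, PoolProcess pool_compute) {
  T ele = pool_compute.initial();  // -FLT_MAX for maxPool<T>, 0 for avgPool<T>
  for (int i = 0; i < pool_size; ++i) {
    pool_compute.compute(ele, window[i]);  // running max, or running sum
  }
  // avgPool divides by pool_size; maxPool leaves ele unchanged.
  pool_compute.finalize(ele, static_cast<T>(pool_size));
  return ele;
}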

@ -17,7 +17,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
int outputSize_pool(int input_size, int filter_size, int padding, int stride) {
int OutputSizePool(int input_size, int filter_size, int padding, int stride) {
int output_size = (input_size - filter_size + 2 * padding) / stride + 1;
return output_size;
}
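A quick arithmetic check of this formula against the shapes used in the deleted benchmarks (illustrative numbers only):

// 2-D test: 128x128 input, ksize 3, padding 0, stride 1
//   (128 - 3 + 2*0) / 1 + 1 = 126  -> 126x126 output
// 3-D test: ksize 3, padding 1, stride 2 per dimension
//   (128 - 3 + 2*1) / 2 + 1 = 64   -> 64 along height and width (integer division)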
@ -33,7 +33,7 @@ class PoolOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_NOT_NULL(ctx.OutputVar("Out"),
"Out(Output) of Pooling should not be null.");
auto in_X = ctx.Input<Tensor>("X");
auto in_x = ctx.Input<Tensor>("X");
auto out = ctx.Output<Tensor>("Out");
int global_pooling = Attr<int>("globalPooling");
std::string pooling_type = Attr<std::string>("poolingType");
@ -43,30 +43,25 @@ class PoolOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(pooling_type == "max" || pooling_type == "avg",
"pooling_type should be 'max' or 'avg'");
PADDLE_ENFORCE(in_X->dims().size() == 4 || in_X->dims().size() == 5,
PADDLE_ENFORCE(in_x->dims().size() == 4 || in_x->dims().size() == 5,
"Pooling intput should be 4-D or 5-D");
PADDLE_ENFORCE(ksize.size() == 2 || ksize.size() == 3,
"Pooling size should be 2 elements. or 3 elements.");
PADDLE_ENFORCE_EQ(ksize.size(), strides.size(),
"strides size and pooling size should be the same.");
PADDLE_ENFORCE_EQ(ksize.size(), paddings.size(),
"paddings size and pooling size should be the same.");
if (global_pooling == 1) {
ksize.resize(static_cast<size_t>(in_X->dims().size()) - 2);
ksize.resize(static_cast<size_t>(in_x->dims().size()) - 2);
for (size_t i = 0; i < ksize.size(); ++i)
ksize[i] = static_cast<int>(in_X->dims()[i + 2]);
ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
}
if (ksize.size() == 2) {
PADDLE_ENFORCE_EQ(strides.size(), 2,
"Pool2DOp strides size should be 2 elements.");
PADDLE_ENFORCE_EQ(paddings.size(), 2,
"Pool2DOp paddings size should be 2 elements");
} else {
PADDLE_ENFORCE_EQ(strides.size(), 3,
"Pool3DOp strides should be 3 elements.");
PADDLE_ENFORCE_EQ(paddings.size(), 3,
"Pool3DOp paddings should be 3 elements.");
}
std::vector<int64_t> output_shape({in_X->dims()[0], in_X->dims()[1]});
std::vector<int64_t> output_shape({in_x->dims()[0], in_x->dims()[1]});
for (size_t i = 0; i < ksize.size(); ++i) {
output_shape.push_back(outputSize_pool(in_X->dims()[i + 2], ksize[i],
paddings[i], strides[i]));
output_shape.push_back(OutputSizePool(in_x->dims()[i + 2], ksize[i],
paddings[i], strides[i]));
}
out->Resize(framework::make_ddim(output_shape));
}
@ -78,9 +73,16 @@ class PoolOpGrad : public framework::OperatorWithKernel {
protected:
void InferShape(const framework::InferShapeContext &ctx) const override {
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"),
"X(Input) of Pooling should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Out"),
"Out(Output) of Pooling should not be null.");
PADDLE_ENFORCE_NOT_NULL(ctx.Output<Tensor>(framework::GradVarName("X")),
"Input@Grad of Pooling should not be null.");
auto in = ctx.Input<Tensor>("X");
auto d_in = ctx.Output<Tensor>(framework::GradVarName("X"));
if (d_in) d_in->Resize(in->dims());
d_in->Resize(in->dims());
}
};
@ -92,7 +94,7 @@ class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker {
"X",
"The input tensor of pooling operator. "
"The format of input tensor is NCHW. Where N is batch size, C is the "
"number of channels, H and W is the height and width of image.");
"number of channels, H and W is the height and width of feature.");
AddOutput("Out",
"The output tensor of pooling operator."
"The format of output tensor is also NCHW.");
@ -166,7 +168,7 @@ class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker {
"The format of input tensor is NCDHW. Where N is batch size, C is "
"the "
"number of channels, D, H and W is the depth, height and width of "
"image.");
"feature.");
AddOutput("Out",
"The output tensor of pooling operator."
"The format of output tensor is also NCDHW.");

@ -1,15 +1,16 @@
/* Copyright (c) 2016 PaddlePaddle Authors All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
http://www.apache.org/licenses/LICENSE-2.0
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/pool_op.h"

@ -28,7 +28,7 @@ template <typename Place, typename T>
class PoolKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* in_X = context.Input<Tensor>("X");
const Tensor* in_x = context.Input<Tensor>("X");
Tensor* out = context.Output<Tensor>("Out");
int global_pooling = context.Attr<int>("globalPooling");
@ -38,43 +38,43 @@ class PoolKernel : public framework::OpKernel {
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
if (global_pooling == 1) {
for (size_t i = 0; i < ksize.size(); ++i) {
ksize[i] = in_X->dims()[i + 2];
ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
}
}
switch (ksize.size()) {
case 2: {
if (pooling_type == "max") {
paddle::operators::math::Pool2dForwardFunctor<
Place, paddle::operators::math::pool::maxPool<T>, T>
paddle::operators::math::Pool2dFunctor<
Place, paddle::operators::math::maxPool<T>, T>
pool2d_forward;
paddle::operators::math::pool::maxPool<T> pool_process;
pool2d_forward(context.device_context(), *in_X, *out, ksize, strides,
paddle::operators::math::maxPool<T> pool_process;
pool2d_forward(context.device_context(), *in_x, *out, ksize, strides,
paddings, pool_process);
} else if (pooling_type == "avg") {
paddle::operators::math::Pool2dForwardFunctor<
Place, paddle::operators::math::pool::avgPool<T>, T>
paddle::operators::math::Pool2dFunctor<
Place, paddle::operators::math::avgPool<T>, T>
pool2d_forward;
paddle::operators::math::pool::avgPool<T> pool_process;
pool2d_forward(context.device_context(), *in_X, *out, ksize, strides,
paddle::operators::math::avgPool<T> pool_process;
pool2d_forward(context.device_context(), *in_x, *out, ksize, strides,
paddings, pool_process);
}
} break;
case 3: {
if (pooling_type == "max") {
paddle::operators::math::Pool3dForwardFunctor<
Place, paddle::operators::math::pool::maxPool<T>, T>
paddle::operators::math::Pool3dFunctor<
Place, paddle::operators::math::maxPool<T>, T>
pool3d_forward;
paddle::operators::math::pool::maxPool<T> pool_process;
pool3d_forward(context.device_context(), *in_X, *out, ksize, strides,
paddle::operators::math::maxPool<T> pool_process;
pool3d_forward(context.device_context(), *in_x, *out, ksize, strides,
paddings, pool_process);
} else if (pooling_type == "avg") {
paddle::operators::math::Pool3dForwardFunctor<
Place, paddle::operators::math::pool::avgPool<T>, T>
paddle::operators::math::Pool3dFunctor<
Place, paddle::operators::math::avgPool<T>, T>
pool3d_forward;
paddle::operators::math::pool::avgPool<T> pool_process;
pool3d_forward(context.device_context(), *in_X, *out, ksize, strides,
paddle::operators::math::avgPool<T> pool_process;
pool3d_forward(context.device_context(), *in_x, *out, ksize, strides,
paddings, pool_process);
}
} break;
@ -86,11 +86,11 @@ template <typename Place, typename T>
class PoolGradKernel : public framework::OpKernel {
public:
void Compute(const framework::ExecutionContext& context) const override {
const Tensor* in_X = context.Input<Tensor>("X");
const Tensor* in_x = context.Input<Tensor>("X");
const Tensor* out = context.Input<Tensor>("Out");
const Tensor* out_grad =
context.Input<Tensor>(framework::GradVarName("Out"));
Tensor* in_X_grad = context.Output<Tensor>(framework::GradVarName("X"));
Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
int global_pooling = context.Attr<int>("globalPooling");
std::string pooling_type = context.Attr<std::string>("poolingType");
@ -99,43 +99,44 @@ class PoolGradKernel : public framework::OpKernel {
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
if (global_pooling == 1) {
for (size_t i = 0; i < ksize.size(); ++i) ksize[i] = in_X->dims()[i + 2];
for (size_t i = 0; i < ksize.size(); ++i)
ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
}
if (in_X_grad) {
in_X_grad->mutable_data<T>(context.GetPlace());
auto temp = framework::EigenVector<T>::Flatten(*in_X_grad);
if (in_x_grad) {
in_x_grad->mutable_data<T>(context.GetPlace());
auto temp = framework::EigenVector<T>::Flatten(*in_x_grad);
temp.device(context.GetEigenDevice<Place>()) =
temp.constant(static_cast<T>(0));
switch (ksize.size()) {
case 2: {
if (pooling_type == "max") {
paddle::operators::math::MaxPool2dBackwardFunctor<Place, T>
paddle::operators::math::MaxPool2dGradFunctor<Place, T>
pool2d_backward;
pool2d_backward(context.device_context(), *in_X, *in_X_grad, *out,
pool2d_backward(context.device_context(), *in_x, *in_x_grad, *out,
*out_grad, ksize, strides, paddings);
} else if (pooling_type == "avg") {
paddle::operators::math::Pool2dBackwardFunctor<
Place, paddle::operators::math::pool::avgPoolGrad<T>, T>
paddle::operators::math::Pool2dGradFunctor<
Place, paddle::operators::math::avgPoolGrad<T>, T>
pool2d_backward;
paddle::operators::math::pool::avgPoolGrad<T> pool_process;
pool2d_backward(context.device_context(), *in_X, *in_X_grad, *out,
paddle::operators::math::avgPoolGrad<T> pool_process;
pool2d_backward(context.device_context(), *in_x, *in_x_grad, *out,
*out_grad, ksize, strides, paddings, pool_process);
}
} break;
case 3: {
if (pooling_type == "max") {
paddle::operators::math::MaxPool3dBackwardFunctor<Place, T>
paddle::operators::math::MaxPool3dGradFunctor<Place, T>
pool3d_backward;
pool3d_backward(context.device_context(), *in_X, *in_X_grad, *out,
pool3d_backward(context.device_context(), *in_x, *in_x_grad, *out,
*out_grad, ksize, strides, paddings);
} else if (pooling_type == "avg") {
paddle::operators::math::Pool3dBackwardFunctor<
Place, paddle::operators::math::pool::avgPoolGrad<T>, T>
paddle::operators::math::Pool3dGradFunctor<
Place, paddle::operators::math::avgPoolGrad<T>, T>
pool3d_backward;
paddle::operators::math::pool::avgPoolGrad<T> pool_process;
pool3d_backward(context.device_context(), *in_X, *in_X_grad, *out,
paddle::operators::math::avgPoolGrad<T> pool_process;
pool3d_backward(context.device_context(), *in_x, *in_x_grad, *out,
*out_grad, ksize, strides, paddings, pool_process);
}
} break;
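As a reading aid for the gradient dispatch above (annotation only, restating the definitions in pooling.h), per window element:

// maxPoolGrad: dx += dy * (x == y);   gradient flows only where the input
//              equals the pooled output (the arg-max positions)
// avgPoolGrad: dx += scale * dy;      gradient is spread uniformly, with
//              scale = 1.0 / pool_size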
