|
|
|
@ -11,8 +11,9 @@ distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/operators/math/pooling.h"
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
@ -27,9 +28,10 @@ template <typename PoolProcess, typename T>
|
|
|
|
|
class Pool2dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(const platform::CPUDeviceContext& context,
|
|
|
|
|
const framework::Tensor& input, std::vector<int>& ksize,
|
|
|
|
|
std::vector<int>& strides, std::vector<int>& paddings,
|
|
|
|
|
PoolProcess pool_process, framework::Tensor* output) {
|
|
|
|
|
const framework::Tensor& input, const std::vector<int>& ksize,
|
|
|
|
|
const std::vector<int>& strides,
|
|
|
|
|
const std::vector<int>& paddings, PoolProcess pool_process,
|
|
|
|
|
framework::Tensor* output) {
|
|
|
|
|
const int batch_size = input.dims()[0];
|
|
|
|
|
const int input_height = input.dims()[2];
|
|
|
|
|
const int input_width = input.dims()[3];
|
|
|
|
@ -63,11 +65,11 @@ class Pool2dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
|
|
|
|
|
T ele = pool_process.initial();
|
|
|
|
|
for (int h = hstart; h < hend; ++h) {
|
|
|
|
|
for (int w = wstart; w < wend; ++w) {
|
|
|
|
|
pool_process.compute(ele, input_data[h * input_width + w]);
|
|
|
|
|
pool_process.compute(input_data[h * input_width + w], &ele);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
int pool_size = (hend - hstart) * (wend - wstart);
|
|
|
|
|
pool_process.finalize(ele, (static_cast<T>(pool_size)));
|
|
|
|
|
pool_process.finalize(static_cast<T>(pool_size), &ele);
|
|
|
|
|
output_data[ph * output_width + pw] = ele;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -86,13 +88,12 @@ class Pool2dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
|
|
|
|
|
template <typename PoolProcess, class T>
|
|
|
|
|
class Pool2dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(const platform::CPUDeviceContext& context,
|
|
|
|
|
const framework::Tensor& input,
|
|
|
|
|
const framework::Tensor& output,
|
|
|
|
|
const framework::Tensor& output_grad, std::vector<int>& ksize,
|
|
|
|
|
std::vector<int>& strides, std::vector<int>& paddings,
|
|
|
|
|
PoolProcess pool_grad_process,
|
|
|
|
|
framework::Tensor* input_grad) {
|
|
|
|
|
void operator()(
|
|
|
|
|
const platform::CPUDeviceContext& context, const framework::Tensor& input,
|
|
|
|
|
const framework::Tensor& output, const framework::Tensor& output_grad,
|
|
|
|
|
const std::vector<int>& ksize, const std::vector<int>& strides,
|
|
|
|
|
const std::vector<int>& paddings, PoolProcess pool_grad_process,
|
|
|
|
|
framework::Tensor* input_grad) {
|
|
|
|
|
const int batch_size = input.dims()[0];
|
|
|
|
|
const int input_height = input.dims()[2];
|
|
|
|
|
const int input_width = input.dims()[3];
|
|
|
|
@ -131,8 +132,8 @@ class Pool2dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
|
|
|
|
|
input_data[h * input_width + w],
|
|
|
|
|
output_data[ph * output_width + pw],
|
|
|
|
|
output_grad_data[ph * output_width + pw],
|
|
|
|
|
input_grad_data[h * input_width + w],
|
|
|
|
|
static_cast<T>(scale));
|
|
|
|
|
static_cast<T>(scale),
|
|
|
|
|
input_grad_data + h * input_width + w);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -154,12 +155,11 @@ class Pool2dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
|
|
|
|
|
template <class T>
|
|
|
|
|
class MaxPool2dGradFunctor<platform::CPUDeviceContext, T> {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(const platform::CPUDeviceContext& context,
|
|
|
|
|
const framework::Tensor& input,
|
|
|
|
|
const framework::Tensor& output,
|
|
|
|
|
const framework::Tensor& output_grad, std::vector<int>& ksize,
|
|
|
|
|
std::vector<int>& strides, std::vector<int>& paddings,
|
|
|
|
|
framework::Tensor* input_grad) {
|
|
|
|
|
void operator()(
|
|
|
|
|
const platform::CPUDeviceContext& context, const framework::Tensor& input,
|
|
|
|
|
const framework::Tensor& output, const framework::Tensor& output_grad,
|
|
|
|
|
const std::vector<int>& ksize, const std::vector<int>& strides,
|
|
|
|
|
const std::vector<int>& paddings, framework::Tensor* input_grad) {
|
|
|
|
|
const int batch_size = input.dims()[0];
|
|
|
|
|
const int input_height = input.dims()[2];
|
|
|
|
|
const int input_width = input.dims()[3];
|
|
|
|
@ -246,9 +246,10 @@ template <typename PoolProcess, class T>
|
|
|
|
|
class Pool3dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(const platform::CPUDeviceContext& context,
|
|
|
|
|
const framework::Tensor& input, std::vector<int>& ksize,
|
|
|
|
|
std::vector<int>& strides, std::vector<int>& paddings,
|
|
|
|
|
PoolProcess pool_process, framework::Tensor* output) {
|
|
|
|
|
const framework::Tensor& input, const std::vector<int>& ksize,
|
|
|
|
|
const std::vector<int>& strides,
|
|
|
|
|
const std::vector<int>& paddings, PoolProcess pool_process,
|
|
|
|
|
framework::Tensor* output) {
|
|
|
|
|
const int batch_size = input.dims()[0];
|
|
|
|
|
const int input_depth = input.dims()[2];
|
|
|
|
|
const int input_height = input.dims()[3];
|
|
|
|
@ -293,14 +294,14 @@ class Pool3dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
|
|
|
|
|
for (int h = hstart; h < hend; ++h) {
|
|
|
|
|
for (int w = wstart; w < wend; ++w) {
|
|
|
|
|
pool_process.compute(
|
|
|
|
|
ele,
|
|
|
|
|
input_data[(d * input_height + h) * input_width + w]);
|
|
|
|
|
input_data[(d * input_height + h) * input_width + w],
|
|
|
|
|
&ele);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
int pool_size =
|
|
|
|
|
(dend - dstart) * (hend - hstart) * (wend - wstart);
|
|
|
|
|
pool_process.finalize(ele, static_cast<T>(pool_size));
|
|
|
|
|
pool_process.finalize(static_cast<T>(pool_size), &ele);
|
|
|
|
|
output_data[output_idx] = ele;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -320,13 +321,12 @@ class Pool3dFunctor<platform::CPUDeviceContext, PoolProcess, T> {
|
|
|
|
|
template <typename PoolProcess, class T>
|
|
|
|
|
class Pool3dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(const platform::CPUDeviceContext& context,
|
|
|
|
|
const framework::Tensor& input,
|
|
|
|
|
const framework::Tensor& output,
|
|
|
|
|
const framework::Tensor& output_grad, std::vector<int>& ksize,
|
|
|
|
|
std::vector<int>& strides, std::vector<int>& paddings,
|
|
|
|
|
PoolProcess pool_grad_process,
|
|
|
|
|
framework::Tensor* input_grad) {
|
|
|
|
|
void operator()(
|
|
|
|
|
const platform::CPUDeviceContext& context, const framework::Tensor& input,
|
|
|
|
|
const framework::Tensor& output, const framework::Tensor& output_grad,
|
|
|
|
|
const std::vector<int>& ksize, const std::vector<int>& strides,
|
|
|
|
|
const std::vector<int>& paddings, PoolProcess pool_grad_process,
|
|
|
|
|
framework::Tensor* input_grad) {
|
|
|
|
|
const int batch_size = input.dims()[0];
|
|
|
|
|
const int input_depth = input.dims()[2];
|
|
|
|
|
const int input_height = input.dims()[3];
|
|
|
|
@ -379,8 +379,8 @@ class Pool3dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
|
|
|
|
|
(pd * output_height + ph) * output_width + pw;
|
|
|
|
|
pool_grad_process.compute(
|
|
|
|
|
input_data[input_idx], output_data[output_idx],
|
|
|
|
|
output_grad_data[output_idx],
|
|
|
|
|
input_grad_data[input_idx], static_cast<T>(scale));
|
|
|
|
|
output_grad_data[output_idx], static_cast<T>(scale),
|
|
|
|
|
input_grad_data + input_idx);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -404,12 +404,11 @@ class Pool3dGradFunctor<platform::CPUDeviceContext, PoolProcess, T> {
|
|
|
|
|
template <class T>
|
|
|
|
|
class MaxPool3dGradFunctor<platform::CPUDeviceContext, T> {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(const platform::CPUDeviceContext& context,
|
|
|
|
|
const framework::Tensor& input,
|
|
|
|
|
const framework::Tensor& output,
|
|
|
|
|
const framework::Tensor& output_grad, std::vector<int>& ksize,
|
|
|
|
|
std::vector<int>& strides, std::vector<int>& paddings,
|
|
|
|
|
framework::Tensor* input_grad) {
|
|
|
|
|
void operator()(
|
|
|
|
|
const platform::CPUDeviceContext& context, const framework::Tensor& input,
|
|
|
|
|
const framework::Tensor& output, const framework::Tensor& output_grad,
|
|
|
|
|
const std::vector<int>& ksize, const std::vector<int>& strides,
|
|
|
|
|
const std::vector<int>& paddings, framework::Tensor* input_grad) {
|
|
|
|
|
const int batch_size = input.dims()[0];
|
|
|
|
|
const int input_depth = input.dims()[2];
|
|
|
|
|
const int input_height = input.dims()[3];
|
|
|
|
@ -510,9 +509,10 @@ template <typename T1, typename T2>
|
|
|
|
|
class MaxPool2dWithIndexFunctor<platform::CPUDeviceContext, T1, T2> {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(const platform::CPUDeviceContext& context,
|
|
|
|
|
const framework::Tensor& input, std::vector<int>& ksize,
|
|
|
|
|
std::vector<int>& strides, std::vector<int>& paddings,
|
|
|
|
|
framework::Tensor* output, framework::Tensor* mask) {
|
|
|
|
|
const framework::Tensor& input, const std::vector<int>& ksize,
|
|
|
|
|
const std::vector<int>& strides,
|
|
|
|
|
const std::vector<int>& paddings, framework::Tensor* output,
|
|
|
|
|
framework::Tensor* mask) {
|
|
|
|
|
const int batch_size = input.dims()[0];
|
|
|
|
|
const int input_height = input.dims()[2];
|
|
|
|
|
const int input_width = input.dims()[3];
|
|
|
|
@ -576,8 +576,9 @@ class MaxPool2dWithIndexGradFunctor<platform::CPUDeviceContext, T1, T2> {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(const platform::CPUDeviceContext& context,
|
|
|
|
|
const framework::Tensor& output_grad,
|
|
|
|
|
const framework::Tensor& mask, std::vector<int>& ksize,
|
|
|
|
|
std::vector<int>& strides, std::vector<int>& paddings,
|
|
|
|
|
const framework::Tensor& mask, const std::vector<int>& ksize,
|
|
|
|
|
const std::vector<int>& strides,
|
|
|
|
|
const std::vector<int>& paddings,
|
|
|
|
|
framework::Tensor* input_grad) {
|
|
|
|
|
const int batch_size = input_grad->dims()[0];
|
|
|
|
|
const int input_height = input_grad->dims()[2];
|
|
|
|
@ -628,9 +629,10 @@ template <typename T1, typename T2>
|
|
|
|
|
class MaxPool3dWithIndexFunctor<platform::CPUDeviceContext, T1, T2> {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(const platform::CPUDeviceContext& context,
|
|
|
|
|
const framework::Tensor& input, std::vector<int>& ksize,
|
|
|
|
|
std::vector<int>& strides, std::vector<int>& paddings,
|
|
|
|
|
framework::Tensor* output, framework::Tensor* mask) {
|
|
|
|
|
const framework::Tensor& input, const std::vector<int>& ksize,
|
|
|
|
|
const std::vector<int>& strides,
|
|
|
|
|
const std::vector<int>& paddings, framework::Tensor* output,
|
|
|
|
|
framework::Tensor* mask) {
|
|
|
|
|
const int batch_size = input.dims()[0];
|
|
|
|
|
const int input_depth = input.dims()[2];
|
|
|
|
|
const int input_height = input.dims()[3];
|
|
|
|
@ -708,8 +710,9 @@ class MaxPool3dWithIndexGradFunctor<platform::CPUDeviceContext, T1, T2> {
|
|
|
|
|
public:
|
|
|
|
|
void operator()(const platform::CPUDeviceContext& context,
|
|
|
|
|
const framework::Tensor& output_grad,
|
|
|
|
|
const framework::Tensor& mask, std::vector<int>& ksize,
|
|
|
|
|
std::vector<int>& strides, std::vector<int>& paddings,
|
|
|
|
|
const framework::Tensor& mask, const std::vector<int>& ksize,
|
|
|
|
|
const std::vector<int>& strides,
|
|
|
|
|
const std::vector<int>& paddings,
|
|
|
|
|
framework::Tensor* input_grad) {
|
|
|
|
|
const int batch_size = input_grad->dims()[0];
|
|
|
|
|
const int input_depth = input_grad->dims()[2];
|
|
|
|
|