remove custom attr checker and fix code format

tonyyang-svail-feed-op-desgin
chengduoZH 8 years ago
parent 3c0f079333
commit e1e3859e88

@ -24,7 +24,7 @@ class Pool2dFunctor<platform::CPUPlace, PoolProcess, T> {
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output, const framework::Tensor& input, framework::Tensor& output,
std::vector<int>& ksize, std::vector<int>& strides, std::vector<int>& ksize, std::vector<int>& strides,
std::vector<int>& paddings, PoolProcess pool_compute) { std::vector<int>& paddings, PoolProcess pool_process) {
const int batch_size = input.dims()[0]; const int batch_size = input.dims()[0];
const int input_height = input.dims()[2]; const int input_height = input.dims()[2];
const int input_width = input.dims()[3]; const int input_width = input.dims()[3];
@ -54,14 +54,15 @@ class Pool2dFunctor<platform::CPUPlace, PoolProcess, T> {
int wstart = pw * stride_width - padding_width; int wstart = pw * stride_width - padding_width;
int wend = std::min(wstart + ksize_width, input_width); int wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0); wstart = std::max(wstart, 0);
T ele = pool_compute.initial();
T ele = pool_process.initial();
for (int h = hstart; h < hend; ++h) { for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) { for (int w = wstart; w < wend; ++w) {
pool_compute.compute(ele, input_data[h * input_width + w]); pool_process.compute(ele, input_data[h * input_width + w]);
} }
} }
int pool_size = (hend - hstart) * (wend - wstart); int pool_size = (hend - hstart) * (wend - wstart);
pool_compute.finalize(ele, (static_cast<T>(pool_size))); pool_process.finalize(ele, (static_cast<T>(pool_size)));
output_data[ph * output_width + pw] = ele; output_data[ph * output_width + pw] = ele;
} }
} }
@ -80,7 +81,7 @@ class Pool2dGradFunctor<platform::CPUPlace, PoolProcess, T> {
const framework::Tensor& output, const framework::Tensor& output,
const framework::Tensor& output_grad, std::vector<int>& ksize, const framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings, std::vector<int>& strides, std::vector<int>& paddings,
PoolProcess pool_compute) { PoolProcess pool_grad_process) {
const int batch_size = input.dims()[0]; const int batch_size = input.dims()[0];
const int input_height = input.dims()[2]; const int input_height = input.dims()[2];
const int input_width = input.dims()[3]; const int input_width = input.dims()[3];
@ -115,11 +116,12 @@ class Pool2dGradFunctor<platform::CPUPlace, PoolProcess, T> {
float scale = 1.0 / pool_size; float scale = 1.0 / pool_size;
for (int h = hstart; h < hend; ++h) { for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) { for (int w = wstart; w < wend; ++w) {
pool_compute.compute(input_data[h * input_width + w], pool_grad_process.compute(
output_data[ph * output_width + pw], input_data[h * input_width + w],
output_grad_data[ph * output_width + pw], output_data[ph * output_width + pw],
input_grad_data[h * input_width + w], output_grad_data[ph * output_width + pw],
static_cast<T>(scale)); input_grad_data[h * input_width + w],
static_cast<T>(scale));
} }
} }
} }
@ -198,21 +200,21 @@ template class MaxPool2dGradFunctor<platform::CPUPlace, float>;
// template class MaxPool2dGradFunctor<platform::CPUPlace, double>; // template class MaxPool2dGradFunctor<platform::CPUPlace, double>;
template class Pool2dFunctor<platform::CPUPlace, template class Pool2dFunctor<platform::CPUPlace,
paddle::operators::math::maxPool<float>, float>; paddle::operators::math::MaxPool<float>, float>;
template class Pool2dFunctor<platform::CPUPlace, template class Pool2dFunctor<platform::CPUPlace,
paddle::operators::math::avgPool<float>, float>; paddle::operators::math::AvgPool<float>, float>;
template class Pool2dGradFunctor< template class Pool2dGradFunctor<
platform::CPUPlace, paddle::operators::math::maxPoolGrad<float>, float>; platform::CPUPlace, paddle::operators::math::MaxPoolGrad<float>, float>;
template class Pool2dGradFunctor< template class Pool2dGradFunctor<
platform::CPUPlace, paddle::operators::math::avgPoolGrad<float>, float>; platform::CPUPlace, paddle::operators::math::AvgPoolGrad<float>, float>;
template class Pool2dFunctor<platform::CPUPlace, template class Pool2dFunctor<platform::CPUPlace,
paddle::operators::math::maxPool<double>, double>; paddle::operators::math::MaxPool<double>, double>;
template class Pool2dFunctor<platform::CPUPlace, template class Pool2dFunctor<platform::CPUPlace,
paddle::operators::math::avgPool<double>, double>; paddle::operators::math::AvgPool<double>, double>;
template class Pool2dGradFunctor< template class Pool2dGradFunctor<
platform::CPUPlace, paddle::operators::math::maxPoolGrad<double>, double>; platform::CPUPlace, paddle::operators::math::MaxPoolGrad<double>, double>;
template class Pool2dGradFunctor< template class Pool2dGradFunctor<
platform::CPUPlace, paddle::operators::math::avgPoolGrad<double>, double>; platform::CPUPlace, paddle::operators::math::AvgPoolGrad<double>, double>;
template <typename PoolProcess, class T> template <typename PoolProcess, class T>
class Pool3dFunctor<platform::CPUPlace, PoolProcess, T> { class Pool3dFunctor<platform::CPUPlace, PoolProcess, T> {
@ -220,7 +222,7 @@ class Pool3dFunctor<platform::CPUPlace, PoolProcess, T> {
void operator()(const platform::DeviceContext& context, void operator()(const platform::DeviceContext& context,
const framework::Tensor& input, framework::Tensor& output, const framework::Tensor& input, framework::Tensor& output,
std::vector<int>& ksize, std::vector<int>& strides, std::vector<int>& ksize, std::vector<int>& strides,
std::vector<int>& paddings, PoolProcess pool_compute) { std::vector<int>& paddings, PoolProcess pool_process) {
const int batch_size = input.dims()[0]; const int batch_size = input.dims()[0];
const int input_depth = input.dims()[2]; const int input_depth = input.dims()[2];
const int input_height = input.dims()[3]; const int input_height = input.dims()[3];
@ -260,11 +262,11 @@ class Pool3dFunctor<platform::CPUPlace, PoolProcess, T> {
int wend = std::min(wstart + ksize_width, input_width); int wend = std::min(wstart + ksize_width, input_width);
wstart = std::max(wstart, 0); wstart = std::max(wstart, 0);
int output_idx = (pd * output_height + ph) * output_width + pw; int output_idx = (pd * output_height + ph) * output_width + pw;
T ele = pool_compute.initial(); T ele = pool_process.initial();
for (int d = dstart; d < dend; ++d) { for (int d = dstart; d < dend; ++d) {
for (int h = hstart; h < hend; ++h) { for (int h = hstart; h < hend; ++h) {
for (int w = wstart; w < wend; ++w) { for (int w = wstart; w < wend; ++w) {
pool_compute.compute( pool_process.compute(
ele, ele,
input_data[(d * input_height + h) * input_width + w]); input_data[(d * input_height + h) * input_width + w]);
} }
@ -272,7 +274,7 @@ class Pool3dFunctor<platform::CPUPlace, PoolProcess, T> {
} }
int pool_size = int pool_size =
(dend - dstart) * (hend - hstart) * (wend - wstart); (dend - dstart) * (hend - hstart) * (wend - wstart);
pool_compute.finalize(ele, static_cast<T>(pool_size)); pool_process.finalize(ele, static_cast<T>(pool_size));
output_data[output_idx] = ele; output_data[output_idx] = ele;
} }
} }
@ -292,7 +294,7 @@ class Pool3dGradFunctor<platform::CPUPlace, PoolProcess, T> {
const framework::Tensor& output, const framework::Tensor& output,
const framework::Tensor& output_grad, std::vector<int>& ksize, const framework::Tensor& output_grad, std::vector<int>& ksize,
std::vector<int>& strides, std::vector<int>& paddings, std::vector<int>& strides, std::vector<int>& paddings,
PoolProcess pool_compute) { PoolProcess pool_grad_process) {
const int batch_size = input.dims()[0]; const int batch_size = input.dims()[0];
const int input_depth = input.dims()[2]; const int input_depth = input.dims()[2];
const int input_height = input.dims()[3]; const int input_height = input.dims()[3];
@ -343,7 +345,7 @@ class Pool3dGradFunctor<platform::CPUPlace, PoolProcess, T> {
int input_idx = (d * input_height + h) * input_width + w; int input_idx = (d * input_height + h) * input_width + w;
int output_idx = int output_idx =
(pd * output_height + ph) * output_width + pw; (pd * output_height + ph) * output_width + pw;
pool_compute.compute( pool_grad_process.compute(
input_data[input_idx], output_data[output_idx], input_data[input_idx], output_data[output_idx],
output_grad_data[output_idx], output_grad_data[output_idx],
input_grad_data[input_idx], static_cast<T>(scale)); input_grad_data[input_idx], static_cast<T>(scale));
@ -441,21 +443,21 @@ template class MaxPool3dGradFunctor<platform::CPUPlace, float>;
// template class MaxPool3dGradFunctor<platform::CPUPlace, double>; // template class MaxPool3dGradFunctor<platform::CPUPlace, double>;
template class Pool3dFunctor<platform::CPUPlace, template class Pool3dFunctor<platform::CPUPlace,
paddle::operators::math::maxPool<float>, float>; paddle::operators::math::MaxPool<float>, float>;
template class Pool3dFunctor<platform::CPUPlace, template class Pool3dFunctor<platform::CPUPlace,
paddle::operators::math::avgPool<float>, float>; paddle::operators::math::AvgPool<float>, float>;
template class Pool3dGradFunctor< template class Pool3dGradFunctor<
platform::CPUPlace, paddle::operators::math::maxPoolGrad<float>, float>; platform::CPUPlace, paddle::operators::math::MaxPoolGrad<float>, float>;
template class Pool3dGradFunctor< template class Pool3dGradFunctor<
platform::CPUPlace, paddle::operators::math::avgPoolGrad<float>, float>; platform::CPUPlace, paddle::operators::math::AvgPoolGrad<float>, float>;
template class Pool3dFunctor<platform::CPUPlace, template class Pool3dFunctor<platform::CPUPlace,
paddle::operators::math::maxPool<double>, double>; paddle::operators::math::MaxPool<double>, double>;
template class Pool3dFunctor<platform::CPUPlace, template class Pool3dFunctor<platform::CPUPlace,
paddle::operators::math::avgPool<double>, double>; paddle::operators::math::AvgPool<double>, double>;
template class Pool3dGradFunctor< template class Pool3dGradFunctor<
platform::CPUPlace, paddle::operators::math::maxPoolGrad<double>, double>; platform::CPUPlace, paddle::operators::math::MaxPoolGrad<double>, double>;
template class Pool3dGradFunctor< template class Pool3dGradFunctor<
platform::CPUPlace, paddle::operators::math::avgPoolGrad<double>, double>; platform::CPUPlace, paddle::operators::math::AvgPoolGrad<double>, double>;
} // namespace math } // namespace math
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle

File diff suppressed because it is too large Load Diff

@ -22,11 +22,10 @@ namespace paddle {
namespace operators { namespace operators {
namespace math { namespace math {
////////////////////// //////////////////////
#define FLT_MAX __FLT_MAX__ #define FLT_MAX __FLT_MAX__ //
/////////////////////
template <class T> template <class T>
class maxPool { class MaxPool {
public: public:
DEVICE inline T initial() { return static_cast<T>(-FLT_MAX); } DEVICE inline T initial() { return static_cast<T>(-FLT_MAX); }
DEVICE inline void compute(T& y, const T& x) { y = y > x ? y : x; } DEVICE inline void compute(T& y, const T& x) { y = y > x ? y : x; }
@ -34,14 +33,14 @@ class maxPool {
}; };
template <class T> template <class T>
class avgPool { class AvgPool {
public: public:
DEVICE inline T initial() { return static_cast<T>(0); } DEVICE inline T initial() { return static_cast<T>(0); }
DEVICE inline void compute(T& y, const T& x) { y += x; } DEVICE inline void compute(T& y, const T& x) { y += x; }
DEVICE inline void finalize(T& y, const T& poo_size) { y /= poo_size; } DEVICE inline void finalize(T& y, const T& poo_size) { y /= poo_size; }
}; };
template <class T> template <class T>
class maxPoolGrad { class MaxPoolGrad {
public: public:
DEVICE inline void compute(const T& x, const T& y, const T& dy, T& dx, DEVICE inline void compute(const T& x, const T& y, const T& dy, T& dx,
T scale) { T scale) {
@ -50,7 +49,7 @@ class maxPoolGrad {
}; };
template <class T> template <class T>
class avgPoolGrad { class AvgPoolGrad {
public: public:
DEVICE inline void compute(const T& x, const T& y, const T& dy, T& dx, DEVICE inline void compute(const T& x, const T& y, const T& dy, T& dx,
T scale) { T scale) {

@ -51,7 +51,7 @@ class PoolOp : public framework::OperatorWithKernel {
ksize[i] = static_cast<int>(in_x_dims[i + 2]); ksize[i] = static_cast<int>(in_x_dims[i + 2]);
} }
PADDLE_ENFORCE(in_x_dims.size() - ksize.size() == 2, PADDLE_ENFORCE(in_x_dims.size() - ksize.size() == 2U,
"Input size and Pooling size should be consistent."); "Input size and Pooling size should be consistent.");
PADDLE_ENFORCE(ksize.size() == 2 || ksize.size() == 3, PADDLE_ENFORCE(ksize.size() == 2 || ksize.size() == 3,
"Pooling size should be 2 elements. or 3 elements."); "Pooling size should be 2 elements. or 3 elements.");
@ -79,7 +79,6 @@ class PoolOpGrad : public framework::OperatorWithKernel {
"X(Input) of Pooling should not be null."); "X(Input) of Pooling should not be null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")), PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Input@Grad of Pooling should not be null."); "Input@Grad of Pooling should not be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X")); ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
} }
}; };
@ -98,66 +97,36 @@ class Pool2dOpMaker : public framework::OpProtoAndCheckerMaker {
"The format of output tensor is also NCHW."); "The format of output tensor is also NCHW.");
AddAttr<std::string>("poolingType", AddAttr<std::string>("poolingType",
"poolingType of pooling operator." "PoolingType of pooling operator."
"str constant equal to 'max' or 'avg'"); "Str constant equal to 'max' or 'avg'.")
.InEnum({"max", "avg"});
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"ksize", "ksize",
"Pooling size(depth, height, width) of pooling operator." "Pooling size(depth, height, width) of pooling operator."
"If globalPooling = true, ksize is ignored and need not be specified."); "If globalPooling = true, ksize is ignored and need not be "
"specified."); // TODO(Add checker)
AddAttr<bool>( AddAttr<bool>(
"globalPooling", "globalPooling",
"whether to use the globalPooling." "Whether to use the globalPooling."
"int constant equal to false or true" "Bool constant equal to false or true."
"default false" "Default false."
"If globalPooling = true, ksize is ignored and need not be specified.") "If globalPooling = true, ksize is ignored and need not be specified.")
.SetDefault(false); .SetDefault(false);
AddAttr<std::vector<int>>("strides", AddAttr<std::vector<int>>("strides",
"strides(height, width) of pooling operator." "Strides(height, width) of pooling operator."
"default {1,1}") "Default {1,1}")
.SetDefault({1, 1}) .SetDefault({1, 1}); // TODO(Add checker)
.AddCustomChecker(GreaterThanChecker_pool({0, 0}));
AddAttr<std::vector<int>>("paddings", AddAttr<std::vector<int>>("paddings",
"paddings(height, width) of pooling operator." "Paddings(height, width) of pooling operator."
"default {0,0}") "Default {0,0}.")
.SetDefault({0, 0}) .SetDefault({0, 0}); // TODO(Add checker)
.AddCustomChecker(EqualGreaterThanChecker_pool({0, 0}));
AddComment(R"DOC( AddComment(R"DOC(
The pooling2d operation calculates the output based on The pooling2d operation calculates the output based on
the input, poolingType and ksize, strides, paddings parameters. the input, poolingType and ksize, strides, paddings parameters.
)DOC"); )DOC");
} }
private:
struct GreaterThanChecker_pool {
public:
explicit GreaterThanChecker_pool(std::vector<int> lower_bound)
: lower_bound_(lower_bound) {}
void operator()(std::vector<int> &value) const {
PADDLE_ENFORCE(value.size() == lower_bound_.size(), "equal check fails.");
for (size_t i = 0; i < value.size(); ++i) {
PADDLE_ENFORCE(value[i] > lower_bound_[i], "larger_than check fails.");
}
}
private:
std::vector<int> lower_bound_;
};
struct EqualGreaterThanChecker_pool {
public:
explicit EqualGreaterThanChecker_pool(std::vector<int> lower_bound)
: lower_bound_(lower_bound) {}
void operator()(std::vector<int> &value) const {
PADDLE_ENFORCE(value.size() == lower_bound_.size(), "equal check fails.");
for (size_t i = 0; i < value.size(); ++i) {
PADDLE_ENFORCE(value[i] >= lower_bound_[i], "larger_than check fails.");
}
}
private:
std::vector<int> lower_bound_;
};
}; };
class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker { class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
Pool3dOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker) Pool3dOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
@ -173,67 +142,36 @@ class Pool3dOpMaker : public framework::OpProtoAndCheckerMaker {
"The format of output tensor is also NCDHW."); "The format of output tensor is also NCDHW.");
AddAttr<std::string>("poolingType", AddAttr<std::string>("poolingType",
"poolingType of pooling operator." "PoolingType of pooling operator."
"str constant equal to 'max' or 'avg'"); "str constant equal to 'max' or 'avg'.")
.InEnum({"max", "avg"});
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"ksize", "ksize",
"pooling size(depth, height, width) of pooling operator." "Pooling size(depth, height, width) of pooling operator."
"If globalPooling = true, ksize is ignored and need not be specified."); "If globalPooling = true, ksize is ignored and need not be "
"specified."); // TODO(Add checker)
AddAttr<bool>( AddAttr<bool>(
"globalPooling", "globalPooling",
"whether to use the globalPooling." "Whether to use the globalPooling."
"int constant equal to false or true" "Bool constant equal to false or true."
"default false" "Default false."
"If globalPooling = true, ksize is ignored and need not be specified.") "If globalPooling = true, ksize is ignored and need not be specified.")
.SetDefault(false); .SetDefault(false);
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"strides", "strides",
"strides(depth, height, width) of pooling operator." "Strides(depth, height, width) of pooling operator."
"default {1,1,1}") "Default {1,1,1}.")
.SetDefault({1, 1, 1}) .SetDefault({1, 1, 1}); // TODO(Add checker)
.AddCustomChecker(GreaterThanChecker_pool({0, 0, 0}));
AddAttr<std::vector<int>>( AddAttr<std::vector<int>>(
"paddings", "paddings",
"paddings(depth, height, width) of pooling operator." "Paddings(depth, height, width) of pooling operator."
"default {0,0,0}") "Default {0,0,0}.")
.SetDefault({0, 0, 0}) .SetDefault({0, 0, 0}); // TODO(Add checker)
.AddCustomChecker(EqualGreaterThanChecker_pool({0, 0, 0}));
AddComment(R"DOC( AddComment(R"DOC(
The pooling3d operation calculates the output based on The pooling3d operation calculates the output based on
the input, poolingType and ksize, strides, paddings parameters. the input, poolingType and ksize, strides, paddings parameters.
)DOC"); )DOC");
} }
private:
struct GreaterThanChecker_pool {
public:
explicit GreaterThanChecker_pool(std::vector<int> lower_bound)
: lower_bound_(lower_bound) {}
void operator()(std::vector<int> &value) const {
PADDLE_ENFORCE(value.size() == lower_bound_.size(), "equal check fails.");
for (size_t i = 0; i < value.size(); ++i) {
PADDLE_ENFORCE(value[i] > lower_bound_[i], "larger_than check fails.");
}
}
private:
std::vector<int> lower_bound_;
};
struct EqualGreaterThanChecker_pool {
public:
explicit EqualGreaterThanChecker_pool(std::vector<int> lower_bound)
: lower_bound_(lower_bound) {}
void operator()(std::vector<int> &value) const {
PADDLE_ENFORCE(value.size() == lower_bound_.size(), "equal check fails.");
for (size_t i = 0; i < value.size(); ++i) {
PADDLE_ENFORCE(value[i] >= lower_bound_[i], "larger_than check fails.");
}
}
private:
std::vector<int> lower_bound_;
};
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle

@ -31,12 +31,11 @@ class PoolKernel : public framework::OpKernel {
const Tensor* in_x = context.Input<Tensor>("X"); const Tensor* in_x = context.Input<Tensor>("X");
Tensor* out = context.Output<Tensor>("Out"); Tensor* out = context.Output<Tensor>("Out");
bool global_pooling = context.Attr<bool>("globalPooling");
std::string pooling_type = context.Attr<std::string>("poolingType"); std::string pooling_type = context.Attr<std::string>("poolingType");
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize"); std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides"); std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
if (global_pooling) { if (context.Attr<bool>("globalPooling")) {
for (size_t i = 0; i < ksize.size(); ++i) { for (size_t i = 0; i < ksize.size(); ++i) {
ksize[i] = static_cast<int>(in_x->dims()[i + 2]); ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
} }
@ -46,17 +45,17 @@ class PoolKernel : public framework::OpKernel {
case 2: { case 2: {
if (pooling_type == "max") { if (pooling_type == "max") {
paddle::operators::math::Pool2dFunctor< paddle::operators::math::Pool2dFunctor<
Place, paddle::operators::math::maxPool<T>, T> Place, paddle::operators::math::MaxPool<T>, T>
pool2d_forward; pool2d_forward;
paddle::operators::math::maxPool<T> pool_process; paddle::operators::math::MaxPool<T> pool_process;
pool2d_forward(context.device_context(), *in_x, *out, ksize, strides, pool2d_forward(context.device_context(), *in_x, *out, ksize, strides,
paddings, pool_process); paddings, pool_process);
} else if (pooling_type == "avg") { } else if (pooling_type == "avg") {
paddle::operators::math::Pool2dFunctor< paddle::operators::math::Pool2dFunctor<
Place, paddle::operators::math::avgPool<T>, T> Place, paddle::operators::math::AvgPool<T>, T>
pool2d_forward; pool2d_forward;
paddle::operators::math::avgPool<T> pool_process; paddle::operators::math::AvgPool<T> pool_process;
pool2d_forward(context.device_context(), *in_x, *out, ksize, strides, pool2d_forward(context.device_context(), *in_x, *out, ksize, strides,
paddings, pool_process); paddings, pool_process);
} }
@ -64,16 +63,16 @@ class PoolKernel : public framework::OpKernel {
case 3: { case 3: {
if (pooling_type == "max") { if (pooling_type == "max") {
paddle::operators::math::Pool3dFunctor< paddle::operators::math::Pool3dFunctor<
Place, paddle::operators::math::maxPool<T>, T> Place, paddle::operators::math::MaxPool<T>, T>
pool3d_forward; pool3d_forward;
paddle::operators::math::maxPool<T> pool_process; paddle::operators::math::MaxPool<T> pool_process;
pool3d_forward(context.device_context(), *in_x, *out, ksize, strides, pool3d_forward(context.device_context(), *in_x, *out, ksize, strides,
paddings, pool_process); paddings, pool_process);
} else if (pooling_type == "avg") { } else if (pooling_type == "avg") {
paddle::operators::math::Pool3dFunctor< paddle::operators::math::Pool3dFunctor<
Place, paddle::operators::math::avgPool<T>, T> Place, paddle::operators::math::AvgPool<T>, T>
pool3d_forward; pool3d_forward;
paddle::operators::math::avgPool<T> pool_process; paddle::operators::math::AvgPool<T> pool_process;
pool3d_forward(context.device_context(), *in_x, *out, ksize, strides, pool3d_forward(context.device_context(), *in_x, *out, ksize, strides,
paddings, pool_process); paddings, pool_process);
} }
@ -92,13 +91,12 @@ class PoolGradKernel : public framework::OpKernel {
context.Input<Tensor>(framework::GradVarName("Out")); context.Input<Tensor>(framework::GradVarName("Out"));
Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X")); Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
bool global_pooling = context.Attr<bool>("globalPooling");
std::string pooling_type = context.Attr<std::string>("poolingType"); std::string pooling_type = context.Attr<std::string>("poolingType");
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize"); std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
std::vector<int> strides = context.Attr<std::vector<int>>("strides"); std::vector<int> strides = context.Attr<std::vector<int>>("strides");
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings"); std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
if (global_pooling) { if (context.Attr<bool>("globalPooling")) {
for (size_t i = 0; i < ksize.size(); ++i) for (size_t i = 0; i < ksize.size(); ++i)
ksize[i] = static_cast<int>(in_x->dims()[i + 2]); ksize[i] = static_cast<int>(in_x->dims()[i + 2]);
} }
@ -118,9 +116,9 @@ class PoolGradKernel : public framework::OpKernel {
*out_grad, ksize, strides, paddings); *out_grad, ksize, strides, paddings);
} else if (pooling_type == "avg") { } else if (pooling_type == "avg") {
paddle::operators::math::Pool2dGradFunctor< paddle::operators::math::Pool2dGradFunctor<
Place, paddle::operators::math::avgPoolGrad<T>, T> Place, paddle::operators::math::AvgPoolGrad<T>, T>
pool2d_backward; pool2d_backward;
paddle::operators::math::avgPoolGrad<T> pool_process; paddle::operators::math::AvgPoolGrad<T> pool_process;
pool2d_backward(context.device_context(), *in_x, *in_x_grad, *out, pool2d_backward(context.device_context(), *in_x, *in_x_grad, *out,
*out_grad, ksize, strides, paddings, pool_process); *out_grad, ksize, strides, paddings, pool_process);
} }
@ -133,9 +131,9 @@ class PoolGradKernel : public framework::OpKernel {
*out_grad, ksize, strides, paddings); *out_grad, ksize, strides, paddings);
} else if (pooling_type == "avg") { } else if (pooling_type == "avg") {
paddle::operators::math::Pool3dGradFunctor< paddle::operators::math::Pool3dGradFunctor<
Place, paddle::operators::math::avgPoolGrad<T>, T> Place, paddle::operators::math::AvgPoolGrad<T>, T>
pool3d_backward; pool3d_backward;
paddle::operators::math::avgPoolGrad<T> pool_process; paddle::operators::math::AvgPoolGrad<T> pool_process;
pool3d_backward(context.device_context(), *in_x, *in_x_grad, *out, pool3d_backward(context.device_context(), *in_x, *in_x_grad, *out,
*out_grad, ksize, strides, paddings, pool_process); *out_grad, ksize, strides, paddings, pool_process);
} }

Loading…
Cancel
Save