|
|
|
@ -21,15 +21,13 @@ limitations under the License. */
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace operators {
|
|
|
|
|
|
|
|
|
|
using Tensor = framework::Tensor;
|
|
|
|
|
|
|
|
|
|
template <typename Place, typename T>
|
|
|
|
|
class UnpoolKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
const Tensor* in_x = context.Input<Tensor>("X");
|
|
|
|
|
const Tensor* in_y = context.Input<Tensor>("Y");
|
|
|
|
|
auto * out = context.Output<Tensor>("Out");
|
|
|
|
|
const framework::Tensor* in_x = context.Input<framework::Tensor>("X");
|
|
|
|
|
const framework::Tensor* in_y = context.Input<framework::Tensor>("Y");
|
|
|
|
|
auto * out = context.Output<framework::Tensor>("Out");
|
|
|
|
|
std::string unpoolingtype = context.Attr<std::string>("unpoolingtype");
|
|
|
|
|
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
|
|
|
|
|
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
|
|
|
|
@ -39,28 +37,22 @@ class UnpoolKernel : public framework::OpKernel<T> {
|
|
|
|
|
math::SetConstant<Place, T> set_zero;
|
|
|
|
|
set_zero(context.device_context(), out, static_cast<T>(0));
|
|
|
|
|
}
|
|
|
|
|
switch (ksize.size()) {
|
|
|
|
|
case 2: {
|
|
|
|
|
if (unpoolingtype == "max") {
|
|
|
|
|
math::Unpool2dMaxFunctor<Place, T> unpool2d_max_forward;
|
|
|
|
|
unpool2d_max_forward(context.device_context(), *in_x, *in_y, out);
|
|
|
|
|
}
|
|
|
|
|
} break;
|
|
|
|
|
default: { PADDLE_THROW("Pool op only supports 2D input."); }
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename Place, typename T>
|
|
|
|
|
class UnpoolGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
const Tensor* in_x = context.Input<Tensor>("X");
|
|
|
|
|
const Tensor* in_y = context.Input<Tensor>("Y");
|
|
|
|
|
const Tensor* out = context.Input<Tensor>("Out");
|
|
|
|
|
const Tensor* out_grad =
|
|
|
|
|
context.Input<Tensor>(framework::GradVarName("Out"));
|
|
|
|
|
Tensor* in_x_grad = context.Output<Tensor>(framework::GradVarName("X"));
|
|
|
|
|
const framework::Tensor* in_x = context.Input<framework::Tensor>("X");
|
|
|
|
|
const framework::Tensor* in_y = context.Input<framework::Tensor>("Y");
|
|
|
|
|
const framework::Tensor* out = context.Input<framework::Tensor>("Out");
|
|
|
|
|
const framework::Tensor* out_grad =
|
|
|
|
|
context.Input<framework::Tensor>(framework::GradVarName("Out"));
|
|
|
|
|
framework::Tensor* in_x_grad =
|
|
|
|
|
context.Output<framework::Tensor>(framework::GradVarName("X"));
|
|
|
|
|
std::string unpoolingtype = context.Attr<std::string>("unpoolingtype");
|
|
|
|
|
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
|
|
|
|
|
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
|
|
|
|
@ -70,19 +62,12 @@ class UnpoolGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
math::SetConstant<Place, T> zero;
|
|
|
|
|
if (in_x_grad) {
|
|
|
|
|
in_x_grad->mutable_data<T>(context.GetPlace());
|
|
|
|
|
zero(device_ctx, in_x_grad, static_cast<T>(0.0));
|
|
|
|
|
zero(device_ctx, in_x_grad, static_cast<T>(0));
|
|
|
|
|
}
|
|
|
|
|
switch (ksize.size()) {
|
|
|
|
|
case 2: {
|
|
|
|
|
if (unpoolingtype == "max") {
|
|
|
|
|
math::Unpool2dMaxGradFunctor<Place, T> unpool2d_max_backward;
|
|
|
|
|
unpool2d_max_backward(context.device_context(), *in_x, *in_y, in_x_grad,
|
|
|
|
|
*out, *out_grad);
|
|
|
|
|
}
|
|
|
|
|
} break;
|
|
|
|
|
default: { PADDLE_THROW("Unpool op only supports 2D input."); }
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|