|
|
|
@ -17,6 +17,7 @@ limitations under the License. */
|
|
|
|
|
#include "paddle/fluid/framework/eigen.h"
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/blas.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/depthwise_conv.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/im2col.h"
|
|
|
|
|
#include "paddle/fluid/operators/math/vol2col.h"
|
|
|
|
|
|
|
|
|
@ -316,5 +317,74 @@ class GemmConvTransposeGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
|
class DepthwiseConvTransposeKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
const Tensor* input = context.Input<Tensor>("Input");
|
|
|
|
|
Tensor filter = *context.Input<Tensor>("Filter");
|
|
|
|
|
Tensor* output = context.Output<Tensor>("Output");
|
|
|
|
|
output->mutable_data<T>(context.GetPlace());
|
|
|
|
|
|
|
|
|
|
int groups = context.Attr<int>("groups");
|
|
|
|
|
PADDLE_ENFORCE_EQ(groups, filter.dims()[0]);
|
|
|
|
|
|
|
|
|
|
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
|
|
|
|
|
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
|
|
|
|
|
std::vector<int> dilations = context.Attr<std::vector<int>>("dilations");
|
|
|
|
|
for (auto v : dilations) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(v, 1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
output->mutable_data<T>(context.GetPlace());
|
|
|
|
|
auto& dev_ctx = context.template device_context<DeviceContext>();
|
|
|
|
|
math::SetConstant<DeviceContext, T> set_zero;
|
|
|
|
|
set_zero(dev_ctx, output, static_cast<T>(0));
|
|
|
|
|
|
|
|
|
|
math::DepthwiseConvInputGradFunctor<DeviceContext, T>
|
|
|
|
|
depthwiseConvInputGrad;
|
|
|
|
|
depthwiseConvInputGrad(dev_ctx, *output, filter, *input, strides, paddings,
|
|
|
|
|
output);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename DeviceContext, typename T>
|
|
|
|
|
class DepthwiseConvTransposeGradKernel : public framework::OpKernel<T> {
|
|
|
|
|
public:
|
|
|
|
|
void Compute(const framework::ExecutionContext& context) const override {
|
|
|
|
|
const Tensor* input = context.Input<Tensor>("Input");
|
|
|
|
|
const Tensor* output_grad =
|
|
|
|
|
context.Input<Tensor>(framework::GradVarName("Output"));
|
|
|
|
|
Tensor* input_grad =
|
|
|
|
|
context.Output<Tensor>(framework::GradVarName("Input"));
|
|
|
|
|
Tensor* filter_grad =
|
|
|
|
|
context.Output<Tensor>(framework::GradVarName("Filter"));
|
|
|
|
|
Tensor filter = *context.Input<Tensor>("Filter");
|
|
|
|
|
|
|
|
|
|
if (!input_grad && !filter_grad) return;
|
|
|
|
|
|
|
|
|
|
auto& dev_ctx = context.template device_context<DeviceContext>();
|
|
|
|
|
std::vector<int> strides = context.Attr<std::vector<int>>("strides");
|
|
|
|
|
std::vector<int> paddings = context.Attr<std::vector<int>>("paddings");
|
|
|
|
|
|
|
|
|
|
if (input_grad) {
|
|
|
|
|
math::DepthwiseConvFunctor<DeviceContext, T> depthwiseConv;
|
|
|
|
|
depthwiseConv(dev_ctx, *output_grad, filter, strides, paddings,
|
|
|
|
|
input_grad);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (filter_grad) {
|
|
|
|
|
math::SetConstant<DeviceContext, T> set_zero;
|
|
|
|
|
filter_grad->mutable_data<T>(context.GetPlace());
|
|
|
|
|
set_zero(dev_ctx, filter_grad, static_cast<T>(0));
|
|
|
|
|
|
|
|
|
|
math::DepthwiseConvFilterGradFunctor<DeviceContext, T>
|
|
|
|
|
depthwiseConvFilterGrad;
|
|
|
|
|
depthwiseConvFilterGrad(dev_ctx, *output_grad, *input, strides, paddings,
|
|
|
|
|
filter_grad);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|