|
|
|
@ -10,6 +10,7 @@
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/operators/interpolate_op.h"
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <vector>
|
|
|
|
|
#include "paddle/fluid/framework/op_registry.h"
|
|
|
|
@ -194,21 +195,46 @@ class InterpolateOpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
|
|
|
|
|
framework::OpKernelType GetExpectedKernelType(
|
|
|
|
|
const framework::ExecutionContext& ctx) const override {
|
|
|
|
|
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
|
|
|
|
|
ctx.GetPlace());
|
|
|
|
|
return framework::OpKernelType(
|
|
|
|
|
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
|
|
|
|
|
ctx.GetPlace());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class InterpolateGradDescMaker : public framework::SingleGradOpDescMaker {
|
|
|
|
|
public:
|
|
|
|
|
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
std::unique_ptr<framework::OpDesc> Apply() const override {
|
|
|
|
|
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
|
|
|
|
|
op->SetType(ForwardOp().Type() + "_grad");
|
|
|
|
|
op->SetInput("X", Input("X"));
|
|
|
|
|
if (ForwardOp().Inputs().count("OutSize") > 0) {
|
|
|
|
|
op->SetInput("OutSize", Input("OutSize"));
|
|
|
|
|
}
|
|
|
|
|
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
|
|
|
|
|
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
|
|
|
|
|
op->SetAttrMap(Attrs());
|
|
|
|
|
return op;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(InterpolateGradNoNeedBufferVarsInference,
|
|
|
|
|
"X");
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
|
namespace ops = paddle::operators;
|
|
|
|
|
REGISTER_OPERATOR(bilinear_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
|
|
|
|
|
paddle::framework::DefaultGradOpDescMaker<true>);
|
|
|
|
|
REGISTER_OPERATOR(bilinear_interp_grad, ops::InterpolateOpGrad);
|
|
|
|
|
ops::InterpolateGradDescMaker);
|
|
|
|
|
REGISTER_OPERATOR(bilinear_interp_grad, ops::InterpolateOpGrad,
|
|
|
|
|
ops::InterpolateGradNoNeedBufferVarsInference);
|
|
|
|
|
REGISTER_OPERATOR(nearest_interp, ops::InterpolateOp, ops::InterpolateOpMaker,
|
|
|
|
|
paddle::framework::DefaultGradOpDescMaker<true>);
|
|
|
|
|
REGISTER_OPERATOR(nearest_interp_grad, ops::InterpolateOpGrad);
|
|
|
|
|
ops::InterpolateGradDescMaker);
|
|
|
|
|
REGISTER_OPERATOR(nearest_interp_grad, ops::InterpolateOpGrad,
|
|
|
|
|
ops::InterpolateGradNoNeedBufferVarsInference);
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(bilinear_interp, ops::InterpolateKernel<float>,
|
|
|
|
|
ops::InterpolateKernel<double>,
|
|
|
|
|
ops::InterpolateKernel<uint8_t>);
|
|
|
|
|