|
|
|
@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/fluid/operators/conv_transpose_op.h"
|
|
|
|
|
#include <memory>
|
|
|
|
|
#include <string>
|
|
|
|
|
#include <vector>
|
|
|
|
|
|
|
|
|
@ -344,6 +345,28 @@ framework::OpKernelType ConvTransposeOpGrad::GetExpectedKernelType(
|
|
|
|
|
ctx.GetPlace(), layout_, library_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
class ConvTransposeGradOpDescMaker : public framework::SingleGradOpDescMaker {
|
|
|
|
|
public:
|
|
|
|
|
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
std::unique_ptr<framework::OpDesc> Apply() const override {
|
|
|
|
|
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
|
|
|
|
|
op->SetType(ForwardOp().Type() + "_grad");
|
|
|
|
|
op->SetInput("Input", Input("Input"));
|
|
|
|
|
op->SetInput("Filter", Input("Filter"));
|
|
|
|
|
op->SetOutput(framework::GradVarName("Input"), InputGrad("Input"));
|
|
|
|
|
op->SetOutput(framework::GradVarName("Filter"), InputGrad("Filter"));
|
|
|
|
|
if (ForwardOp().Inputs().count("Bias") > 0) {
|
|
|
|
|
op->SetInput("Bias", Input("Bias"));
|
|
|
|
|
op->SetOutput(framework::GradVarName("Bias"), InputGrad("Bias"));
|
|
|
|
|
}
|
|
|
|
|
op->SetInput(framework::GradVarName("Output"), OutputGrad("Output"));
|
|
|
|
|
op->SetAttrMap(Attrs());
|
|
|
|
|
return op;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
} // namespace operators
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|
|
|
|
@ -352,7 +375,7 @@ namespace ops = paddle::operators;
|
|
|
|
|
// conv2d_transpose
|
|
|
|
|
REGISTER_OPERATOR(conv2d_transpose, ops::ConvTransposeOp,
|
|
|
|
|
ops::Conv2DTransposeOpMaker,
|
|
|
|
|
paddle::framework::DefaultGradOpDescMaker<true>);
|
|
|
|
|
ops::ConvTransposeGradOpDescMaker);
|
|
|
|
|
REGISTER_OPERATOR(conv2d_transpose_grad, ops::ConvTransposeOpGrad);
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
@ -368,7 +391,7 @@ REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
// conv3d_transpose
|
|
|
|
|
REGISTER_OPERATOR(conv3d_transpose, ops::ConvTransposeOp,
|
|
|
|
|
ops::Conv3DTransposeOpMaker,
|
|
|
|
|
paddle::framework::DefaultGradOpDescMaker<true>);
|
|
|
|
|
ops::ConvTransposeGradOpDescMaker);
|
|
|
|
|
REGISTER_OPERATOR(conv3d_transpose_grad, ops::ConvTransposeOpGrad);
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
@ -384,7 +407,7 @@ REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
// depthwise conv2d_transpose
|
|
|
|
|
REGISTER_OPERATOR(depthwise_conv2d_transpose, ops::ConvTransposeOp,
|
|
|
|
|
ops::Conv2DTransposeOpMaker,
|
|
|
|
|
paddle::framework::DefaultGradOpDescMaker<true>);
|
|
|
|
|
ops::ConvTransposeGradOpDescMaker);
|
|
|
|
|
REGISTER_OPERATOR(depthwise_conv2d_transpose_grad, ops::ConvTransposeOpGrad);
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|