|
|
|
@ -272,6 +272,20 @@ class Transpose2GradMaker : public framework::SingleGradOpMaker<T> {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
|
class Transpose2DoubleGradMaker : public framework::SingleGradOpMaker<T> {
|
|
|
|
|
public:
|
|
|
|
|
using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
|
|
|
|
|
|
|
|
|
|
void Apply(GradOpPtr<T> grad_op) const override {
|
|
|
|
|
grad_op->SetType("transpose2");
|
|
|
|
|
grad_op->SetInput("X", this->OutputGrad(framework::GradVarName("X")));
|
|
|
|
|
grad_op->SetOutput("Out", this->InputGrad(framework::GradVarName("Out")));
|
|
|
|
|
grad_op->SetOutput("XShape", this->Input("XShape"));
|
|
|
|
|
grad_op->SetAttrMap(this->Attrs());
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class Transpose2OpGrad : public framework::OperatorWithKernel {
|
|
|
|
|
public:
|
|
|
|
|
using framework::OperatorWithKernel::OperatorWithKernel;
|
|
|
|
@ -338,7 +352,9 @@ REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
REGISTER_OPERATOR(transpose2, ops::Transpose2Op, ops::Transpose2OpMaker,
|
|
|
|
|
ops::Transpose2GradMaker<paddle::framework::OpDesc>,
|
|
|
|
|
ops::Transpose2GradMaker<paddle::imperative::OpBase>);
|
|
|
|
|
REGISTER_OPERATOR(transpose2_grad, ops::Transpose2OpGrad);
|
|
|
|
|
REGISTER_OPERATOR(transpose2_grad, ops::Transpose2OpGrad,
|
|
|
|
|
ops::Transpose2DoubleGradMaker<paddle::framework::OpDesc>,
|
|
|
|
|
ops::Transpose2DoubleGradMaker<paddle::imperative::OpBase>);
|
|
|
|
|
|
|
|
|
|
REGISTER_OP_CPU_KERNEL(
|
|
|
|
|
transpose2, ops::TransposeKernel<paddle::platform::CPUDeviceContext, float>,
|
|
|
|
|