[Complex] Add real & imag op and api for complex tensor (#29672)
	
		
	
				
					
				
			* add complex real op & api & unittest * add imag op & api & unittest * refactor op impl * revert simplify writing due to complile failed * polish details * polish grad op coderevert-31562-mean
							parent
							
								
									9eff1a674f
								
							
						
					
					
						commit
						6cfa59de1b
					
				@ -0,0 +1,106 @@
 | 
				
			||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||
 | 
				
			||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
you may not use this file except in compliance with the License.
 | 
				
			||||
You may obtain a copy of the License at
 | 
				
			||||
 | 
				
			||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 | 
				
			||||
Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
See the License for the specific language governing permissions and
 | 
				
			||||
limitations under the License. */
 | 
				
			||||
 | 
				
			||||
#include "paddle/fluid/operators/imag_op.h"
 | 
				
			||||
 | 
				
			||||
namespace paddle {
 | 
				
			||||
namespace operators {
 | 
				
			||||
 | 
				
			||||
class ImagOp : public framework::OperatorWithKernel {
 | 
				
			||||
 public:
 | 
				
			||||
  using framework::OperatorWithKernel::OperatorWithKernel;
 | 
				
			||||
 | 
				
			||||
  void InferShape(framework::InferShapeContext* ctx) const override {
 | 
				
			||||
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Imag");
 | 
				
			||||
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Imag");
 | 
				
			||||
 | 
				
			||||
    auto x_dims = ctx->GetInputDim("X");
 | 
				
			||||
    ctx->SetOutputDim("Out", x_dims);
 | 
				
			||||
    ctx->ShareLoD("X", "Out");
 | 
				
			||||
  }
 | 
				
			||||
};
 | 
				
			||||
 | 
				
			||||
class ImagOpMaker : public framework::OpProtoAndCheckerMaker {
 | 
				
			||||
 public:
 | 
				
			||||
  void Make() override {
 | 
				
			||||
    AddInput("X", "(Tensor), The input tensor of imag op.");
 | 
				
			||||
    AddOutput("Out", "(Tensor), The output tensor of imag op.");
 | 
				
			||||
    AddComment(R"DOC(
 | 
				
			||||
Imag Operator.
 | 
				
			||||
 | 
				
			||||
This operator is used to get a new tensor containing imaginary values
 | 
				
			||||
from a tensor with complex data type.
 | 
				
			||||
 | 
				
			||||
)DOC");
 | 
				
			||||
  }
 | 
				
			||||
};
 | 
				
			||||
 | 
				
			||||
class ImagGradOp : public framework::OperatorWithKernel {
 | 
				
			||||
 public:
 | 
				
			||||
  using framework::OperatorWithKernel::OperatorWithKernel;
 | 
				
			||||
  void InferShape(framework::InferShapeContext* ctx) const override {
 | 
				
			||||
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
 | 
				
			||||
                   "Out@Grad", "ImagGrad");
 | 
				
			||||
    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
 | 
				
			||||
                   "X@Grad", "ImagGrad");
 | 
				
			||||
 | 
				
			||||
    auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
 | 
				
			||||
    ctx->SetOutputDim(framework::GradVarName("X"), dout_dims);
 | 
				
			||||
  }
 | 
				
			||||
 | 
				
			||||
 protected:
 | 
				
			||||
  framework::OpKernelType GetExpectedKernelType(
 | 
				
			||||
      const framework::ExecutionContext& ctx) const override {
 | 
				
			||||
    auto dtype = OperatorWithKernel::IndicateVarDataType(
 | 
				
			||||
        ctx, framework::GradVarName("Out"));
 | 
				
			||||
    auto complex_dtype = framework::ToComplexType(dtype);
 | 
				
			||||
    return framework::OpKernelType(complex_dtype, ctx.GetPlace());
 | 
				
			||||
  }
 | 
				
			||||
};
 | 
				
			||||
 | 
				
			||||
template <typename T>
 | 
				
			||||
class ImagGradOpMaker : public framework::SingleGradOpMaker<T> {
 | 
				
			||||
 public:
 | 
				
			||||
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
 | 
				
			||||
  void Apply(GradOpPtr<T> grad_op) const override {
 | 
				
			||||
    grad_op->SetType("imag_grad");
 | 
				
			||||
    grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
 | 
				
			||||
    grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
 | 
				
			||||
  }
 | 
				
			||||
};
 | 
				
			||||
 | 
				
			||||
DECLARE_INPLACE_OP_INFERER(ImagOpInplaceInferer, {"X", "Out"});
 | 
				
			||||
DECLARE_INPLACE_OP_INFERER(ImagGradOpInplaceInferer,
 | 
				
			||||
                           {framework::GradVarName("Out"),
 | 
				
			||||
                            framework::GradVarName("X")});
 | 
				
			||||
 | 
				
			||||
}  // namespace operators
 | 
				
			||||
}  // namespace paddle
 | 
				
			||||
 | 
				
			||||
namespace ops = paddle::operators;
 | 
				
			||||
 | 
				
			||||
REGISTER_OPERATOR(imag, ops::ImagOp, ops::ImagOpMaker,
 | 
				
			||||
                  ops::ImagGradOpMaker<paddle::framework::OpDesc>,
 | 
				
			||||
                  ops::ImagGradOpMaker<paddle::imperative::OpBase>);
 | 
				
			||||
REGISTER_OPERATOR(imag_grad, ops::ImagGradOp);
 | 
				
			||||
 | 
				
			||||
REGISTER_OP_CPU_KERNEL(imag, ops::ImagKernel<paddle::platform::CPUDeviceContext,
 | 
				
			||||
                                             paddle::platform::complex64>,
 | 
				
			||||
                       ops::ImagKernel<paddle::platform::CPUDeviceContext,
 | 
				
			||||
                                       paddle::platform::complex128>);
 | 
				
			||||
REGISTER_OP_CPU_KERNEL(imag_grad,
 | 
				
			||||
                       ops::ImagGradKernel<paddle::platform::CPUDeviceContext,
 | 
				
			||||
                                           paddle::platform::complex64>,
 | 
				
			||||
                       ops::ImagGradKernel<paddle::platform::CPUDeviceContext,
 | 
				
			||||
                                           paddle::platform::complex128>);
 | 
				
			||||
@ -0,0 +1,28 @@
 | 
				
			||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||
 | 
				
			||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
you may not use this file except in compliance with the License.
 | 
				
			||||
You may obtain a copy of the License at
 | 
				
			||||
 | 
				
			||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 | 
				
			||||
Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
See the License for the specific language governing permissions and
 | 
				
			||||
limitations under the License. */
 | 
				
			||||
 | 
				
			||||
#include "paddle/fluid/operators/imag_op.h"
 | 
				
			||||
 | 
				
			||||
namespace ops = paddle::operators;
 | 
				
			||||
 | 
				
			||||
REGISTER_OP_CUDA_KERNEL(imag,
 | 
				
			||||
                        ops::ImagKernel<paddle::platform::CUDADeviceContext,
 | 
				
			||||
                                        paddle::platform::complex64>,
 | 
				
			||||
                        ops::ImagKernel<paddle::platform::CUDADeviceContext,
 | 
				
			||||
                                        paddle::platform::complex128>);
 | 
				
			||||
REGISTER_OP_CUDA_KERNEL(imag_grad,
 | 
				
			||||
                        ops::ImagGradKernel<paddle::platform::CUDADeviceContext,
 | 
				
			||||
                                            paddle::platform::complex64>,
 | 
				
			||||
                        ops::ImagGradKernel<paddle::platform::CUDADeviceContext,
 | 
				
			||||
                                            paddle::platform::complex128>);
 | 
				
			||||
@ -0,0 +1,66 @@
 | 
				
			||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||
 | 
				
			||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
you may not use this file except in compliance with the License.
 | 
				
			||||
You may obtain a copy of the License at
 | 
				
			||||
 | 
				
			||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 | 
				
			||||
Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
See the License for the specific language governing permissions and
 | 
				
			||||
limitations under the License. */
 | 
				
			||||
 | 
				
			||||
#pragma once
 | 
				
			||||
 | 
				
			||||
#include "paddle/fluid/framework/data_type.h"
 | 
				
			||||
#include "paddle/fluid/framework/op_registry.h"
 | 
				
			||||
#include "paddle/fluid/operators/math/complex_functors.h"
 | 
				
			||||
#include "paddle/fluid/platform/for_range.h"
 | 
				
			||||
 | 
				
			||||
namespace paddle {
 | 
				
			||||
namespace operators {
 | 
				
			||||
 | 
				
			||||
template <typename DeviceContext, typename T>
 | 
				
			||||
class ImagKernel : public framework::OpKernel<T> {
 | 
				
			||||
 public:
 | 
				
			||||
  void Compute(const framework::ExecutionContext& ctx) const {
 | 
				
			||||
    const framework::Tensor* x = ctx.Input<framework::Tensor>("X");
 | 
				
			||||
    framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
 | 
				
			||||
 | 
				
			||||
    auto numel = x->numel();
 | 
				
			||||
    auto* x_data = x->data<T>();
 | 
				
			||||
    auto* out_data = out->mutable_data<math::Real<T>>(
 | 
				
			||||
        ctx.GetPlace(), static_cast<size_t>(numel * sizeof(math::Real<T>)));
 | 
				
			||||
 | 
				
			||||
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
 | 
				
			||||
    platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
 | 
				
			||||
    math::ImagFunctor<T> functor(x_data, out_data, numel);
 | 
				
			||||
    for_range(functor);
 | 
				
			||||
  }
 | 
				
			||||
};
 | 
				
			||||
 | 
				
			||||
template <typename DeviceContext, typename T>
 | 
				
			||||
class ImagGradKernel : public framework::OpKernel<T> {
 | 
				
			||||
 public:
 | 
				
			||||
  void Compute(const framework::ExecutionContext& ctx) const {
 | 
				
			||||
    const framework::Tensor* d_out =
 | 
				
			||||
        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
 | 
				
			||||
    framework::Tensor* d_x =
 | 
				
			||||
        ctx.Output<framework::Tensor>(framework::GradVarName("X"));
 | 
				
			||||
 | 
				
			||||
    auto numel = d_out->numel();
 | 
				
			||||
    auto* dout_data = d_out->data<math::Real<T>>();
 | 
				
			||||
    auto* dx_data = d_x->mutable_data<T>(
 | 
				
			||||
        ctx.GetPlace(), static_cast<size_t>(numel * sizeof(T)));
 | 
				
			||||
 | 
				
			||||
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
 | 
				
			||||
    platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
 | 
				
			||||
    math::ImagToComplexFunctor<T> functor(dout_data, dx_data, numel);
 | 
				
			||||
    for_range(functor);
 | 
				
			||||
  }
 | 
				
			||||
};
 | 
				
			||||
 | 
				
			||||
}  // namespace operators
 | 
				
			||||
}  // namespace paddle
 | 
				
			||||
@ -0,0 +1,140 @@
 | 
				
			||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||
 | 
				
			||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
you may not use this file except in compliance with the License.
 | 
				
			||||
You may obtain a copy of the License at
 | 
				
			||||
 | 
				
			||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 | 
				
			||||
Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
See the License for the specific language governing permissions and
 | 
				
			||||
limitations under the License. */
 | 
				
			||||
 | 
				
			||||
#pragma once
 | 
				
			||||
 | 
				
			||||
#include <type_traits>
 | 
				
			||||
 | 
				
			||||
#include "paddle/fluid/platform/complex128.h"
 | 
				
			||||
#include "paddle/fluid/platform/complex64.h"
 | 
				
			||||
#include "paddle/fluid/platform/hostdevice.h"
 | 
				
			||||
 | 
				
			||||
namespace paddle {
 | 
				
			||||
namespace operators {
 | 
				
			||||
namespace math {
 | 
				
			||||
 | 
				
			||||
template <bool B, typename T>
 | 
				
			||||
struct cond {
 | 
				
			||||
  static constexpr bool value = B;
 | 
				
			||||
  using type = T;
 | 
				
			||||
};
 | 
				
			||||
 | 
				
			||||
template <bool B, typename TrueF, typename FalseF>
 | 
				
			||||
struct eval_if {
 | 
				
			||||
  using type = typename TrueF::type;
 | 
				
			||||
};
 | 
				
			||||
 | 
				
			||||
template <typename TrueF, typename FalseF>
 | 
				
			||||
struct eval_if<false, TrueF, FalseF> {
 | 
				
			||||
  using type = typename FalseF::type;
 | 
				
			||||
};
 | 
				
			||||
 | 
				
			||||
template <bool B, typename T, typename F>
 | 
				
			||||
using eval_if_t = typename eval_if<B, T, F>::type;
 | 
				
			||||
 | 
				
			||||
template <typename Head, typename... Tail>
 | 
				
			||||
struct select {
 | 
				
			||||
  using type = eval_if_t<Head::value, Head, select<Tail...>>;
 | 
				
			||||
};
 | 
				
			||||
 | 
				
			||||
template <typename Head, typename... Tail>
 | 
				
			||||
using select_t = typename select<Head, Tail...>::type;
 | 
				
			||||
 | 
				
			||||
template <typename T>
 | 
				
			||||
using Real =
 | 
				
			||||
    select_t<cond<std::is_same<T, platform::complex64>::value, float>,
 | 
				
			||||
             cond<std::is_same<T, platform::complex128>::value, double>, T>;
 | 
				
			||||
 | 
				
			||||
template <typename T, typename RealT>
 | 
				
			||||
using Complex = typename std::enable_if<!std::is_same<T, RealT>::value>::type;
 | 
				
			||||
 | 
				
			||||
// There are no NoComplex cases now, implement later if needed
 | 
				
			||||
template <typename T, typename RealT>
 | 
				
			||||
using NoComplex = typename std::enable_if<std::is_same<T, RealT>::value>::type;
 | 
				
			||||
 | 
				
			||||
template <typename T, typename Enable = void>
 | 
				
			||||
struct RealFunctor;
 | 
				
			||||
 | 
				
			||||
template <typename T>
 | 
				
			||||
struct RealFunctor<T, Complex<T, Real<T>>> {
 | 
				
			||||
 public:
 | 
				
			||||
  RealFunctor(const T* input, Real<T>* output, int64_t numel)
 | 
				
			||||
      : input_(input), output_(output), numel_(numel) {}
 | 
				
			||||
 | 
				
			||||
  HOSTDEVICE void operator()(int64_t idx) const {
 | 
				
			||||
    output_[idx] = input_[idx].real;
 | 
				
			||||
  }
 | 
				
			||||
 | 
				
			||||
 private:
 | 
				
			||||
  const T* input_;
 | 
				
			||||
  Real<T>* output_;
 | 
				
			||||
  int64_t numel_;
 | 
				
			||||
};
 | 
				
			||||
 | 
				
			||||
template <typename T, typename Enable = void>
 | 
				
			||||
struct ImagFunctor;
 | 
				
			||||
 | 
				
			||||
template <typename T>
 | 
				
			||||
struct ImagFunctor<T, Complex<T, Real<T>>> {
 | 
				
			||||
  ImagFunctor(const T* input, Real<T>* output, int64_t numel)
 | 
				
			||||
      : input_(input), output_(output), numel_(numel) {}
 | 
				
			||||
 | 
				
			||||
  HOSTDEVICE void operator()(int64_t idx) const {
 | 
				
			||||
    output_[idx] = input_[idx].imag;
 | 
				
			||||
  }
 | 
				
			||||
 | 
				
			||||
  const T* input_;
 | 
				
			||||
  Real<T>* output_;
 | 
				
			||||
  int64_t numel_;
 | 
				
			||||
};
 | 
				
			||||
 | 
				
			||||
template <typename T, typename Enable = void>
 | 
				
			||||
struct RealToComplexFunctor;
 | 
				
			||||
 | 
				
			||||
template <typename T>
 | 
				
			||||
struct RealToComplexFunctor<T, Complex<T, Real<T>>> {
 | 
				
			||||
  RealToComplexFunctor(const Real<T>* input, T* output, int64_t numel)
 | 
				
			||||
      : input_(input), output_(output), numel_(numel) {}
 | 
				
			||||
 | 
				
			||||
  HOSTDEVICE void operator()(int64_t idx) const {
 | 
				
			||||
    output_[idx].real = input_[idx];
 | 
				
			||||
    output_[idx].imag = 0;
 | 
				
			||||
  }
 | 
				
			||||
 | 
				
			||||
  const Real<T>* input_;
 | 
				
			||||
  T* output_;
 | 
				
			||||
  int64_t numel_;
 | 
				
			||||
};
 | 
				
			||||
 | 
				
			||||
template <typename T, typename Enable = void>
 | 
				
			||||
struct ImagToComplexFunctor;
 | 
				
			||||
 | 
				
			||||
template <typename T>
 | 
				
			||||
struct ImagToComplexFunctor<T, Complex<T, Real<T>>> {
 | 
				
			||||
  ImagToComplexFunctor(const Real<T>* input, T* output, int64_t numel)
 | 
				
			||||
      : input_(input), output_(output), numel_(numel) {}
 | 
				
			||||
 | 
				
			||||
  HOSTDEVICE void operator()(int64_t idx) const {
 | 
				
			||||
    output_[idx].real = 0;
 | 
				
			||||
    output_[idx].imag = input_[idx];
 | 
				
			||||
  }
 | 
				
			||||
 | 
				
			||||
  const Real<T>* input_;
 | 
				
			||||
  T* output_;
 | 
				
			||||
  int64_t numel_;
 | 
				
			||||
};
 | 
				
			||||
 | 
				
			||||
}  // namespace math
 | 
				
			||||
}  // namespace operators
 | 
				
			||||
}  // namespace paddle
 | 
				
			||||
@ -0,0 +1,105 @@
 | 
				
			||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||
 | 
				
			||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
you may not use this file except in compliance with the License.
 | 
				
			||||
You may obtain a copy of the License at
 | 
				
			||||
 | 
				
			||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 | 
				
			||||
Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
See the License for the specific language governing permissions and
 | 
				
			||||
limitations under the License. */
 | 
				
			||||
 | 
				
			||||
#include "paddle/fluid/operators/real_op.h"
 | 
				
			||||
 | 
				
			||||
namespace paddle {
 | 
				
			||||
namespace operators {
 | 
				
			||||
 | 
				
			||||
class RealOp : public framework::OperatorWithKernel {
 | 
				
			||||
 public:
 | 
				
			||||
  using framework::OperatorWithKernel::OperatorWithKernel;
 | 
				
			||||
  void InferShape(framework::InferShapeContext* ctx) const override {
 | 
				
			||||
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Real");
 | 
				
			||||
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Real");
 | 
				
			||||
 | 
				
			||||
    auto x_dims = ctx->GetInputDim("X");
 | 
				
			||||
    ctx->SetOutputDim("Out", x_dims);
 | 
				
			||||
    ctx->ShareLoD("X", "Out");
 | 
				
			||||
  }
 | 
				
			||||
};
 | 
				
			||||
 | 
				
			||||
class RealOpMaker : public framework::OpProtoAndCheckerMaker {
 | 
				
			||||
 public:
 | 
				
			||||
  void Make() override {
 | 
				
			||||
    AddInput("X", "(Tensor), The input tensor of real op.");
 | 
				
			||||
    AddOutput("Out", "(Tensor), The output tensor of real op.");
 | 
				
			||||
    AddComment(R"DOC( 
 | 
				
			||||
Real Operator. 
 | 
				
			||||
 | 
				
			||||
This operator is used to get a new tensor containing real values 
 | 
				
			||||
from a tensor with complex data type.
 | 
				
			||||
 | 
				
			||||
)DOC");
 | 
				
			||||
  }
 | 
				
			||||
};
 | 
				
			||||
 | 
				
			||||
class RealGradOp : public framework::OperatorWithKernel {
 | 
				
			||||
 public:
 | 
				
			||||
  using framework::OperatorWithKernel::OperatorWithKernel;
 | 
				
			||||
  void InferShape(framework::InferShapeContext* ctx) const override {
 | 
				
			||||
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
 | 
				
			||||
                   "Out@Grad", "RealGrad");
 | 
				
			||||
    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
 | 
				
			||||
                   "X@Grad", "RealGrad");
 | 
				
			||||
 | 
				
			||||
    auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
 | 
				
			||||
    ctx->SetOutputDim(framework::GradVarName("X"), dout_dims);
 | 
				
			||||
  }
 | 
				
			||||
 | 
				
			||||
 protected:
 | 
				
			||||
  framework::OpKernelType GetExpectedKernelType(
 | 
				
			||||
      const framework::ExecutionContext& ctx) const override {
 | 
				
			||||
    auto dtype = OperatorWithKernel::IndicateVarDataType(
 | 
				
			||||
        ctx, framework::GradVarName("Out"));
 | 
				
			||||
    auto complex_dtype = framework::ToComplexType(dtype);
 | 
				
			||||
    return framework::OpKernelType(complex_dtype, ctx.GetPlace());
 | 
				
			||||
  }
 | 
				
			||||
};
 | 
				
			||||
 | 
				
			||||
template <typename T>
 | 
				
			||||
class RealGradOpMaker : public framework::SingleGradOpMaker<T> {
 | 
				
			||||
 public:
 | 
				
			||||
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
 | 
				
			||||
  void Apply(GradOpPtr<T> grad_op) const override {
 | 
				
			||||
    grad_op->SetType("real_grad");
 | 
				
			||||
    grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
 | 
				
			||||
    grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
 | 
				
			||||
  }
 | 
				
			||||
};
 | 
				
			||||
 | 
				
			||||
DECLARE_INPLACE_OP_INFERER(RealOpInplaceInferer, {"X", "Out"});
 | 
				
			||||
DECLARE_INPLACE_OP_INFERER(RealGradOpInplaceInferer,
 | 
				
			||||
                           {framework::GradVarName("Out"),
 | 
				
			||||
                            framework::GradVarName("X")});
 | 
				
			||||
 | 
				
			||||
}  // namespace operators
 | 
				
			||||
}  // namespace paddle
 | 
				
			||||
 | 
				
			||||
namespace ops = paddle::operators;
 | 
				
			||||
 | 
				
			||||
REGISTER_OPERATOR(real, ops::RealOp, ops::RealOpMaker,
 | 
				
			||||
                  ops::RealGradOpMaker<::paddle::framework::OpDesc>,
 | 
				
			||||
                  ops::RealGradOpMaker<::paddle::imperative::OpBase>);
 | 
				
			||||
REGISTER_OPERATOR(real_grad, ops::RealGradOp);
 | 
				
			||||
 | 
				
			||||
REGISTER_OP_CPU_KERNEL(real, ops::RealKernel<paddle::platform::CPUDeviceContext,
 | 
				
			||||
                                             paddle::platform::complex64>,
 | 
				
			||||
                       ops::RealKernel<paddle::platform::CPUDeviceContext,
 | 
				
			||||
                                       paddle::platform::complex128>);
 | 
				
			||||
REGISTER_OP_CPU_KERNEL(real_grad,
 | 
				
			||||
                       ops::RealGradKernel<paddle::platform::CPUDeviceContext,
 | 
				
			||||
                                           paddle::platform::complex64>,
 | 
				
			||||
                       ops::RealGradKernel<paddle::platform::CPUDeviceContext,
 | 
				
			||||
                                           paddle::platform::complex128>);
 | 
				
			||||
@ -0,0 +1,28 @@
 | 
				
			||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||
 | 
				
			||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
you may not use this file except in compliance with the License.
 | 
				
			||||
You may obtain a copy of the License at
 | 
				
			||||
 | 
				
			||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 | 
				
			||||
Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
See the License for the specific language governing permissions and
 | 
				
			||||
limitations under the License. */
 | 
				
			||||
 | 
				
			||||
#include "paddle/fluid/operators/real_op.h"
 | 
				
			||||
 | 
				
			||||
namespace ops = paddle::operators;
 | 
				
			||||
 | 
				
			||||
REGISTER_OP_CUDA_KERNEL(real,
 | 
				
			||||
                        ops::RealKernel<paddle::platform::CUDADeviceContext,
 | 
				
			||||
                                        paddle::platform::complex64>,
 | 
				
			||||
                        ops::RealKernel<paddle::platform::CUDADeviceContext,
 | 
				
			||||
                                        paddle::platform::complex128>);
 | 
				
			||||
REGISTER_OP_CUDA_KERNEL(real_grad,
 | 
				
			||||
                        ops::RealGradKernel<paddle::platform::CUDADeviceContext,
 | 
				
			||||
                                            paddle::platform::complex64>,
 | 
				
			||||
                        ops::RealGradKernel<paddle::platform::CUDADeviceContext,
 | 
				
			||||
                                            paddle::platform::complex128>);
 | 
				
			||||
@ -0,0 +1,66 @@
 | 
				
			||||
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||
 | 
				
			||||
Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
you may not use this file except in compliance with the License.
 | 
				
			||||
You may obtain a copy of the License at
 | 
				
			||||
 | 
				
			||||
    http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
 | 
				
			||||
Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
See the License for the specific language governing permissions and
 | 
				
			||||
limitations under the License. */
 | 
				
			||||
 | 
				
			||||
#pragma once
 | 
				
			||||
 | 
				
			||||
#include "paddle/fluid/framework/data_type.h"
 | 
				
			||||
#include "paddle/fluid/framework/op_registry.h"
 | 
				
			||||
#include "paddle/fluid/operators/math/complex_functors.h"
 | 
				
			||||
#include "paddle/fluid/platform/for_range.h"
 | 
				
			||||
 | 
				
			||||
namespace paddle {
 | 
				
			||||
namespace operators {
 | 
				
			||||
 | 
				
			||||
template <typename DeviceContext, typename T>
 | 
				
			||||
class RealKernel : public framework::OpKernel<T> {
 | 
				
			||||
 public:
 | 
				
			||||
  void Compute(const framework::ExecutionContext& ctx) const {
 | 
				
			||||
    const framework::Tensor* x = ctx.Input<framework::Tensor>("X");
 | 
				
			||||
    framework::Tensor* out = ctx.Output<framework::Tensor>("Out");
 | 
				
			||||
 | 
				
			||||
    auto numel = x->numel();
 | 
				
			||||
    auto* x_data = x->data<T>();
 | 
				
			||||
    auto* out_data = out->mutable_data<math::Real<T>>(
 | 
				
			||||
        ctx.GetPlace(), static_cast<size_t>(numel * sizeof(math::Real<T>)));
 | 
				
			||||
 | 
				
			||||
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
 | 
				
			||||
    platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
 | 
				
			||||
    math::RealFunctor<T> functor(x_data, out_data, numel);
 | 
				
			||||
    for_range(functor);
 | 
				
			||||
  }
 | 
				
			||||
};
 | 
				
			||||
 | 
				
			||||
template <typename DeviceContext, typename T>
 | 
				
			||||
class RealGradKernel : public framework::OpKernel<T> {
 | 
				
			||||
 public:
 | 
				
			||||
  void Compute(const framework::ExecutionContext& ctx) const {
 | 
				
			||||
    const framework::Tensor* d_out =
 | 
				
			||||
        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
 | 
				
			||||
    framework::Tensor* d_x =
 | 
				
			||||
        ctx.Output<framework::Tensor>(framework::GradVarName("X"));
 | 
				
			||||
 | 
				
			||||
    auto numel = d_out->numel();
 | 
				
			||||
    auto* dout_data = d_out->data<math::Real<T>>();
 | 
				
			||||
    auto* dx_data = d_x->mutable_data<T>(
 | 
				
			||||
        ctx.GetPlace(), static_cast<size_t>(numel * sizeof(T)));
 | 
				
			||||
 | 
				
			||||
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
 | 
				
			||||
    platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
 | 
				
			||||
    math::RealToComplexFunctor<T> functor(dout_data, dx_data, numel);
 | 
				
			||||
    for_range(functor);
 | 
				
			||||
  }
 | 
				
			||||
};
 | 
				
			||||
 | 
				
			||||
}  // namespace operators
 | 
				
			||||
}  // namespace paddle
 | 
				
			||||
@ -0,0 +1,167 @@
 | 
				
			||||
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
 | 
				
			||||
#
 | 
				
			||||
# Licensed under the Apache License, Version 2.0 (the "License");
 | 
				
			||||
# you may not use this file except in compliance with the License.
 | 
				
			||||
# You may obtain a copy of the License at
 | 
				
			||||
#
 | 
				
			||||
#     http://www.apache.org/licenses/LICENSE-2.0
 | 
				
			||||
#
 | 
				
			||||
# Unless required by applicable law or agreed to in writing, software
 | 
				
			||||
# distributed under the License is distributed on an "AS IS" BASIS,
 | 
				
			||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 | 
				
			||||
# See the License for the specific language governing permissions and
 | 
				
			||||
# limitations under the License.
 | 
				
			||||
 | 
				
			||||
from __future__ import print_function
 | 
				
			||||
 | 
				
			||||
import unittest
 | 
				
			||||
import numpy as np
 | 
				
			||||
 | 
				
			||||
import paddle
 | 
				
			||||
import paddle.fluid as fluid
 | 
				
			||||
import paddle.static as static
 | 
				
			||||
from op_test import OpTest
 | 
				
			||||
 | 
				
			||||
numpy_apis = {
 | 
				
			||||
    "real": np.real,
 | 
				
			||||
    "imag": np.imag,
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
paddle_apis = {
 | 
				
			||||
    "real": paddle.real,
 | 
				
			||||
    "imag": paddle.imag,
 | 
				
			||||
}
 | 
				
			||||
 | 
				
			||||
 | 
				
			||||
class TestRealOp(OpTest):
 | 
				
			||||
    def setUp(self):
 | 
				
			||||
        # switch to static
 | 
				
			||||
        paddle.enable_static()
 | 
				
			||||
        # op test attrs
 | 
				
			||||
        self.op_type = "real"
 | 
				
			||||
        self.dtype = np.float64
 | 
				
			||||
        self.init_input_output()
 | 
				
			||||
        # backward attrs
 | 
				
			||||
        self.init_grad_input_output()
 | 
				
			||||
 | 
				
			||||
    def init_input_output(self):
 | 
				
			||||
        self.inputs = {
 | 
				
			||||
            'X': np.random.random(
 | 
				
			||||
                (20, 5)).astype(self.dtype) + 1j * np.random.random(
 | 
				
			||||
                    (20, 5)).astype(self.dtype)
 | 
				
			||||
        }
 | 
				
			||||
        self.outputs = {'Out': numpy_apis[self.op_type](self.inputs['X'])}
 | 
				
			||||
 | 
				
			||||
    def init_grad_input_output(self):
 | 
				
			||||
        self.grad_out = np.ones((20, 5), self.dtype)
 | 
				
			||||
        self.grad_x = np.real(self.grad_out) + 1j * np.zeros(
 | 
				
			||||
            self.grad_out.shape)
 | 
				
			||||
 | 
				
			||||
    def test_check_output(self):
 | 
				
			||||
        self.check_output()
 | 
				
			||||
 | 
				
			||||
    def test_check_grad(self):
 | 
				
			||||
        self.check_grad(
 | 
				
			||||
            ['X'],
 | 
				
			||||
            'Out',
 | 
				
			||||
            user_defined_grads=[self.grad_x],
 | 
				
			||||
            user_defined_grad_outputs=[self.grad_out])
 | 
				
			||||
 | 
				
			||||
 | 
				
			||||
class TestImagOp(TestRealOp):
 | 
				
			||||
    def setUp(self):
 | 
				
			||||
        # switch to static
 | 
				
			||||
        paddle.enable_static()
 | 
				
			||||
        # op test attrs
 | 
				
			||||
        self.op_type = "imag"
 | 
				
			||||
        self.dtype = np.float64
 | 
				
			||||
        self.init_input_output()
 | 
				
			||||
        # backward attrs
 | 
				
			||||
        self.init_grad_input_output()
 | 
				
			||||
 | 
				
			||||
    def init_grad_input_output(self):
 | 
				
			||||
        self.grad_out = np.ones((20, 5), self.dtype)
 | 
				
			||||
        self.grad_x = np.zeros(self.grad_out.shape) + 1j * np.real(
 | 
				
			||||
            self.grad_out)
 | 
				
			||||
 | 
				
			||||
 | 
				
			||||
class TestRealAPI(unittest.TestCase):
 | 
				
			||||
    def setUp(self):
 | 
				
			||||
        # switch to static
 | 
				
			||||
        paddle.enable_static()
 | 
				
			||||
        # prepare test attrs
 | 
				
			||||
        self.api = "real"
 | 
				
			||||
        self.dtypes = ["complex64", "complex128"]
 | 
				
			||||
        self.places = [paddle.CPUPlace()]
 | 
				
			||||
        if paddle.is_compiled_with_cuda():
 | 
				
			||||
            self.places.append(paddle.CUDAPlace(0))
 | 
				
			||||
        self._shape = [2, 20, 2, 3]
 | 
				
			||||
 | 
				
			||||
    def test_in_static_mode(self):
 | 
				
			||||
        def init_input_output(dtype):
 | 
				
			||||
            input = np.random.random(self._shape).astype(
 | 
				
			||||
                dtype) + 1j * np.random.random(self._shape).astype(dtype)
 | 
				
			||||
            return {'x': input}, numpy_apis[self.api](input)
 | 
				
			||||
 | 
				
			||||
        for dtype in self.dtypes:
 | 
				
			||||
            input_dict, np_res = init_input_output(dtype)
 | 
				
			||||
            for place in self.places:
 | 
				
			||||
                with static.program_guard(static.Program()):
 | 
				
			||||
                    x = static.data(name="x", shape=self._shape, dtype=dtype)
 | 
				
			||||
                    out = paddle_apis[self.api](x)
 | 
				
			||||
 | 
				
			||||
                    exe = static.Executor(place)
 | 
				
			||||
                    out_value = exe.run(feed=input_dict, fetch_list=[out.name])
 | 
				
			||||
                    self.assertTrue(np.array_equal(np_res, out_value[0]))
 | 
				
			||||
 | 
				
			||||
    def test_in_dynamic_mode(self):
 | 
				
			||||
        for dtype in self.dtypes:
 | 
				
			||||
            input = np.random.random(self._shape).astype(
 | 
				
			||||
                dtype) + 1j * np.random.random(self._shape).astype(dtype)
 | 
				
			||||
            np_res = numpy_apis[self.api](input)
 | 
				
			||||
            for place in self.places:
 | 
				
			||||
                # it is more convenient to use `guard` than `enable/disable_**` here
 | 
				
			||||
                with fluid.dygraph.guard(place):
 | 
				
			||||
                    input_t = paddle.to_tensor(input)
 | 
				
			||||
                    res = paddle_apis[self.api](input_t).numpy()
 | 
				
			||||
                    self.assertTrue(np.array_equal(np_res, res))
 | 
				
			||||
                    res_t = input_t.real().numpy(
 | 
				
			||||
                    ) if self.api is "real" else input_t.imag().numpy()
 | 
				
			||||
                    self.assertTrue(np.array_equal(np_res, res_t))
 | 
				
			||||
 | 
				
			||||
    def test_name_argument(self):
 | 
				
			||||
        with static.program_guard(static.Program()):
 | 
				
			||||
            x = static.data(name="x", shape=self._shape, dtype=self.dtypes[0])
 | 
				
			||||
            out = paddle_apis[self.api](x, name="real_res")
 | 
				
			||||
            self.assertTrue("real_res" in out.name)
 | 
				
			||||
 | 
				
			||||
    def test_dtype_error(self):
 | 
				
			||||
        # in static mode
 | 
				
			||||
        with self.assertRaises(TypeError):
 | 
				
			||||
            with static.program_guard(static.Program()):
 | 
				
			||||
                x = static.data(name="x", shape=self._shape, dtype="float32")
 | 
				
			||||
                out = paddle_apis[self.api](x, name="real_res")
 | 
				
			||||
 | 
				
			||||
        # in dynamic mode
 | 
				
			||||
        with self.assertRaises(RuntimeError):
 | 
				
			||||
            with fluid.dygraph.guard():
 | 
				
			||||
                input = np.random.random(self._shape).astype("float32")
 | 
				
			||||
                input_t = paddle.to_tensor(input)
 | 
				
			||||
                res = paddle_apis[self.api](input_t)
 | 
				
			||||
 | 
				
			||||
 | 
				
			||||
class TestImagAPI(TestRealAPI):
 | 
				
			||||
    def setUp(self):
 | 
				
			||||
        # switch to static
 | 
				
			||||
        paddle.enable_static()
 | 
				
			||||
        # prepare test attrs
 | 
				
			||||
        self.api = "imag"
 | 
				
			||||
        self.dtypes = ["complex64", "complex128"]
 | 
				
			||||
        self.places = [paddle.CPUPlace()]
 | 
				
			||||
        if paddle.is_compiled_with_cuda():
 | 
				
			||||
            self.places.append(paddle.CUDAPlace(0))
 | 
				
			||||
        self._shape = [2, 20, 2, 3]
 | 
				
			||||
 | 
				
			||||
 | 
				
			||||
if __name__ == "__main__":
 | 
				
			||||
    unittest.main()
 | 
				
			||||
					Loading…
					
					
				
		Reference in new issue