[Complex] Add real & imag op and api for complex tensor (#29672)
* add complex real op & api & unittest
* add imag op & api & unittest
* refactor op impl
* revert simplify writing due to compile failure
* polish details
* polish grad op code
parent 9eff1a674f
commit 6cfa59de1b
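For context, the new ops are exposed in Python as paddle.real / paddle.imag and as the Tensor methods real() / imag(). A minimal dygraph usage sketch, mirroring the unit tests at the end of this diff (the printed values are illustrative):

import numpy as np
import paddle

x_np = np.array([1 + 2j, 3 + 4j], dtype="complex64")
x = paddle.to_tensor(x_np)
print(paddle.real(x).numpy())  # [1. 3.]
print(paddle.imag(x).numpy())  # [2. 4.]
# tensor-method form, also exercised by the tests below
print(x.real().numpy(), x.imag().numpy())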
@@ -0,0 +1,106 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/imag_op.h"

namespace paddle {
namespace operators {

class ImagOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Imag");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Imag");

    auto x_dims = ctx->GetInputDim("X");
    ctx->SetOutputDim("Out", x_dims);
    ctx->ShareLoD("X", "Out");
  }
};

class ImagOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor), The input tensor of imag op.");
    AddOutput("Out", "(Tensor), The output tensor of imag op.");
    AddComment(R"DOC(
Imag Operator.

This operator is used to get a new tensor containing imaginary values
from a tensor with complex data type.

)DOC");
  }
};

class ImagGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
                   "Out@Grad", "ImagGrad");
    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
                   "X@Grad", "ImagGrad");

    auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
    ctx->SetOutputDim(framework::GradVarName("X"), dout_dims);
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    auto dtype = OperatorWithKernel::IndicateVarDataType(
        ctx, framework::GradVarName("Out"));
    auto complex_dtype = framework::ToComplexType(dtype);
    return framework::OpKernelType(complex_dtype, ctx.GetPlace());
  }
};

template <typename T>
class ImagGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  void Apply(GradOpPtr<T> grad_op) const override {
    grad_op->SetType("imag_grad");
    grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
  }
};

DECLARE_INPLACE_OP_INFERER(ImagOpInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(ImagGradOpInplaceInferer,
                           {framework::GradVarName("Out"),
                            framework::GradVarName("X")});

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(imag, ops::ImagOp, ops::ImagOpMaker,
                  ops::ImagGradOpMaker<paddle::framework::OpDesc>,
                  ops::ImagGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(imag_grad, ops::ImagGradOp);

REGISTER_OP_CPU_KERNEL(imag, ops::ImagKernel<paddle::platform::CPUDeviceContext,
                                             paddle::platform::complex64>,
                       ops::ImagKernel<paddle::platform::CPUDeviceContext,
                                       paddle::platform::complex128>);
REGISTER_OP_CPU_KERNEL(imag_grad,
                       ops::ImagGradKernel<paddle::platform::CPUDeviceContext,
                                           paddle::platform::complex64>,
                       ops::ImagGradKernel<paddle::platform::CPUDeviceContext,
                                           paddle::platform::complex128>);
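Note the dtype handling in ImagGradOp::GetExpectedKernelType above: Out@GRAD is real-valued (the output of imag is real), so its dtype is promoted back through framework::ToComplexType to select a complex kernel for X@GRAD. A numpy sketch of the resulting backward semantics, matching the user_defined_grads in the unit tests below (values illustrative):

import numpy as np

dout = np.ones((2,), dtype=np.float32)     # real-valued upstream gradient
dx_imag = np.zeros_like(dout) + 1j * dout  # imag_grad: dL/dx = 0 + 1j * dout
dx_real = dout + 1j * np.zeros_like(dout)  # real_grad: dL/dx = dout + 0j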
@@ -0,0 +1,28 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/imag_op.h"

namespace ops = paddle::operators;

REGISTER_OP_CUDA_KERNEL(imag,
                        ops::ImagKernel<paddle::platform::CUDADeviceContext,
                                        paddle::platform::complex64>,
                        ops::ImagKernel<paddle::platform::CUDADeviceContext,
                                        paddle::platform::complex128>);
REGISTER_OP_CUDA_KERNEL(imag_grad,
                        ops::ImagGradKernel<paddle::platform::CUDADeviceContext,
                                            paddle::platform::complex64>,
                        ops::ImagGradKernel<paddle::platform::CUDADeviceContext,
                                            paddle::platform::complex128>);
@@ -0,0 +1,66 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/platform/for_range.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class ImagKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const {
    const framework::Tensor* x = ctx.Input<framework::Tensor>("X");
    framework::Tensor* out = ctx.Output<framework::Tensor>("Out");

    auto numel = x->numel();
    auto* x_data = x->data<T>();
    auto* out_data = out->mutable_data<math::Real<T>>(
        ctx.GetPlace(), static_cast<size_t>(numel * sizeof(math::Real<T>)));

    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
    math::ImagFunctor<T> functor(x_data, out_data, numel);
    for_range(functor);
  }
};

template <typename DeviceContext, typename T>
class ImagGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const {
    const framework::Tensor* d_out =
        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    framework::Tensor* d_x =
        ctx.Output<framework::Tensor>(framework::GradVarName("X"));

    auto numel = d_out->numel();
    auto* dout_data = d_out->data<math::Real<T>>();
    auto* dx_data = d_x->mutable_data<T>(
        ctx.GetPlace(), static_cast<size_t>(numel * sizeof(T)));

    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
    math::ImagToComplexFunctor<T> functor(dout_data, dx_data, numel);
    for_range(functor);
  }
};

}  // namespace operators
}  // namespace paddle
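Both kernels drive an elementwise functor through platform::ForRange, which applies the functor once per index: a sequential loop on CPUDeviceContext and one GPU thread per index on CUDADeviceContext. Conceptually (a Python sketch of the behavior, not the actual implementation):

def for_range(numel, functor):
    # CPU: a plain loop; CUDA: each idx runs on its own thread in parallel
    for idx in range(numel):
        functor(idx)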
@@ -0,0 +1,140 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <type_traits>

#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/hostdevice.h"

namespace paddle {
namespace operators {
namespace math {

template <bool B, typename T>
struct cond {
  static constexpr bool value = B;
  using type = T;
};

template <bool B, typename TrueF, typename FalseF>
struct eval_if {
  using type = typename TrueF::type;
};

template <typename TrueF, typename FalseF>
struct eval_if<false, TrueF, FalseF> {
  using type = typename FalseF::type;
};

template <bool B, typename T, typename F>
using eval_if_t = typename eval_if<B, T, F>::type;

template <typename Head, typename... Tail>
struct select {
  using type = eval_if_t<Head::value, Head, select<Tail...>>;
};

template <typename Head, typename... Tail>
using select_t = typename select<Head, Tail...>::type;

template <typename T>
using Real =
    select_t<cond<std::is_same<T, platform::complex64>::value, float>,
             cond<std::is_same<T, platform::complex128>::value, double>, T>;

template <typename T, typename RealT>
using Complex = typename std::enable_if<!std::is_same<T, RealT>::value>::type;

// There are no NoComplex cases now, implement later if needed
template <typename T, typename RealT>
using NoComplex = typename std::enable_if<std::is_same<T, RealT>::value>::type;

template <typename T, typename Enable = void>
struct RealFunctor;

template <typename T>
struct RealFunctor<T, Complex<T, Real<T>>> {
 public:
  RealFunctor(const T* input, Real<T>* output, int64_t numel)
      : input_(input), output_(output), numel_(numel) {}

  HOSTDEVICE void operator()(int64_t idx) const {
    output_[idx] = input_[idx].real;
  }

 private:
  const T* input_;
  Real<T>* output_;
  int64_t numel_;
};

template <typename T, typename Enable = void>
struct ImagFunctor;

template <typename T>
struct ImagFunctor<T, Complex<T, Real<T>>> {
  ImagFunctor(const T* input, Real<T>* output, int64_t numel)
      : input_(input), output_(output), numel_(numel) {}

  HOSTDEVICE void operator()(int64_t idx) const {
    output_[idx] = input_[idx].imag;
  }

  const T* input_;
  Real<T>* output_;
  int64_t numel_;
};

template <typename T, typename Enable = void>
struct RealToComplexFunctor;

template <typename T>
struct RealToComplexFunctor<T, Complex<T, Real<T>>> {
  RealToComplexFunctor(const Real<T>* input, T* output, int64_t numel)
      : input_(input), output_(output), numel_(numel) {}

  HOSTDEVICE void operator()(int64_t idx) const {
    output_[idx].real = input_[idx];
    output_[idx].imag = 0;
  }

  const Real<T>* input_;
  T* output_;
  int64_t numel_;
};

template <typename T, typename Enable = void>
struct ImagToComplexFunctor;

template <typename T>
struct ImagToComplexFunctor<T, Complex<T, Real<T>>> {
  ImagToComplexFunctor(const Real<T>* input, T* output, int64_t numel)
      : input_(input), output_(output), numel_(numel) {}

  HOSTDEVICE void operator()(int64_t idx) const {
    output_[idx].real = 0;
    output_[idx].imag = input_[idx];
  }

  const Real<T>* input_;
  T* output_;
  int64_t numel_;
};

}  // namespace math
}  // namespace operators
}  // namespace paddle
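The cond / eval_if / select machinery above is a small compile-time dispatch: Real<T> resolves to float for platform::complex64, to double for platform::complex128, and to T itself otherwise (the fallback branch is only instantiated when actually used, which is why no base case is needed for the types registered here). An illustrative Python analog of the mapping, with hypothetical names chosen purely for exposition:

import numpy as np

REAL_OF = {np.complex64: np.float32, np.complex128: np.float64}

def real_type(t):
    # analog of math::Real<T>: complex dtypes map to their real counterpart,
    # anything else is returned unchanged
    return REAL_OF.get(t, t)

assert real_type(np.complex64) is np.float32
assert real_type(np.float64) is np.float64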
@@ -0,0 +1,105 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/real_op.h"

namespace paddle {
namespace operators {

class RealOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "Real");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Real");

    auto x_dims = ctx->GetInputDim("X");
    ctx->SetOutputDim("Out", x_dims);
    ctx->ShareLoD("X", "Out");
  }
};

class RealOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("X", "(Tensor), The input tensor of real op.");
    AddOutput("Out", "(Tensor), The output tensor of real op.");
    AddComment(R"DOC(
Real Operator.

This operator is used to get a new tensor containing real values
from a tensor with complex data type.

)DOC");
  }
};

class RealGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Out")), "Input",
                   "Out@Grad", "RealGrad");
    OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
                   "X@Grad", "RealGrad");

    auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
    ctx->SetOutputDim(framework::GradVarName("X"), dout_dims);
  }

 protected:
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    auto dtype = OperatorWithKernel::IndicateVarDataType(
        ctx, framework::GradVarName("Out"));
    auto complex_dtype = framework::ToComplexType(dtype);
    return framework::OpKernelType(complex_dtype, ctx.GetPlace());
  }
};

template <typename T>
class RealGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
  void Apply(GradOpPtr<T> grad_op) const override {
    grad_op->SetType("real_grad");
    grad_op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
    grad_op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
  }
};

DECLARE_INPLACE_OP_INFERER(RealOpInplaceInferer, {"X", "Out"});
DECLARE_INPLACE_OP_INFERER(RealGradOpInplaceInferer,
                           {framework::GradVarName("Out"),
                            framework::GradVarName("X")});

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(real, ops::RealOp, ops::RealOpMaker,
                  ops::RealGradOpMaker<::paddle::framework::OpDesc>,
                  ops::RealGradOpMaker<::paddle::imperative::OpBase>);
REGISTER_OPERATOR(real_grad, ops::RealGradOp);

REGISTER_OP_CPU_KERNEL(real, ops::RealKernel<paddle::platform::CPUDeviceContext,
                                             paddle::platform::complex64>,
                       ops::RealKernel<paddle::platform::CPUDeviceContext,
                                       paddle::platform::complex128>);
REGISTER_OP_CPU_KERNEL(real_grad,
                       ops::RealGradKernel<paddle::platform::CPUDeviceContext,
                                           paddle::platform::complex64>,
                       ops::RealGradKernel<paddle::platform::CPUDeviceContext,
                                           paddle::platform::complex128>);
@@ -0,0 +1,28 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/real_op.h"

namespace ops = paddle::operators;

REGISTER_OP_CUDA_KERNEL(real,
                        ops::RealKernel<paddle::platform::CUDADeviceContext,
                                        paddle::platform::complex64>,
                        ops::RealKernel<paddle::platform::CUDADeviceContext,
                                        paddle::platform::complex128>);
REGISTER_OP_CUDA_KERNEL(real_grad,
                        ops::RealGradKernel<paddle::platform::CUDADeviceContext,
                                            paddle::platform::complex64>,
                        ops::RealGradKernel<paddle::platform::CUDADeviceContext,
                                            paddle::platform::complex128>);
@@ -0,0 +1,66 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/platform/for_range.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class RealKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const {
    const framework::Tensor* x = ctx.Input<framework::Tensor>("X");
    framework::Tensor* out = ctx.Output<framework::Tensor>("Out");

    auto numel = x->numel();
    auto* x_data = x->data<T>();
    auto* out_data = out->mutable_data<math::Real<T>>(
        ctx.GetPlace(), static_cast<size_t>(numel * sizeof(math::Real<T>)));

    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
    math::RealFunctor<T> functor(x_data, out_data, numel);
    for_range(functor);
  }
};

template <typename DeviceContext, typename T>
class RealGradKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const {
    const framework::Tensor* d_out =
        ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
    framework::Tensor* d_x =
        ctx.Output<framework::Tensor>(framework::GradVarName("X"));

    auto numel = d_out->numel();
    auto* dout_data = d_out->data<math::Real<T>>();
    auto* dx_data = d_x->mutable_data<T>(
        ctx.GetPlace(), static_cast<size_t>(numel * sizeof(T)));

    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    platform::ForRange<DeviceContext> for_range(dev_ctx, numel);
    math::RealToComplexFunctor<T> functor(dout_data, dx_data, numel);
    for_range(functor);
  }
};

}  // namespace operators
}  // namespace paddle
@@ -0,0 +1,167 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import numpy as np

import paddle
import paddle.fluid as fluid
import paddle.static as static
from op_test import OpTest

numpy_apis = {
    "real": np.real,
    "imag": np.imag,
}

paddle_apis = {
    "real": paddle.real,
    "imag": paddle.imag,
}


class TestRealOp(OpTest):
    def setUp(self):
        # switch to static
        paddle.enable_static()
        # op test attrs
        self.op_type = "real"
        self.dtype = np.float64
        self.init_input_output()
        # backward attrs
        self.init_grad_input_output()

    def init_input_output(self):
        self.inputs = {
            'X': np.random.random(
                (20, 5)).astype(self.dtype) + 1j * np.random.random(
                    (20, 5)).astype(self.dtype)
        }
        self.outputs = {'Out': numpy_apis[self.op_type](self.inputs['X'])}

    def init_grad_input_output(self):
        self.grad_out = np.ones((20, 5), self.dtype)
        self.grad_x = np.real(self.grad_out) + 1j * np.zeros(
            self.grad_out.shape)

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(
            ['X'],
            'Out',
            user_defined_grads=[self.grad_x],
            user_defined_grad_outputs=[self.grad_out])


class TestImagOp(TestRealOp):
    def setUp(self):
        # switch to static
        paddle.enable_static()
        # op test attrs
        self.op_type = "imag"
        self.dtype = np.float64
        self.init_input_output()
        # backward attrs
        self.init_grad_input_output()

    def init_grad_input_output(self):
        self.grad_out = np.ones((20, 5), self.dtype)
        self.grad_x = np.zeros(self.grad_out.shape) + 1j * np.real(
            self.grad_out)


class TestRealAPI(unittest.TestCase):
    def setUp(self):
        # switch to static
        paddle.enable_static()
        # prepare test attrs
        self.api = "real"
        self.dtypes = ["complex64", "complex128"]
        self.places = [paddle.CPUPlace()]
        if paddle.is_compiled_with_cuda():
            self.places.append(paddle.CUDAPlace(0))
        self._shape = [2, 20, 2, 3]

    def test_in_static_mode(self):
        def init_input_output(dtype):
            input = np.random.random(self._shape).astype(
                dtype) + 1j * np.random.random(self._shape).astype(dtype)
            return {'x': input}, numpy_apis[self.api](input)

        for dtype in self.dtypes:
            input_dict, np_res = init_input_output(dtype)
            for place in self.places:
                with static.program_guard(static.Program()):
                    x = static.data(name="x", shape=self._shape, dtype=dtype)
                    out = paddle_apis[self.api](x)

                    exe = static.Executor(place)
                    out_value = exe.run(feed=input_dict, fetch_list=[out.name])
                    self.assertTrue(np.array_equal(np_res, out_value[0]))

    def test_in_dynamic_mode(self):
        for dtype in self.dtypes:
            input = np.random.random(self._shape).astype(
                dtype) + 1j * np.random.random(self._shape).astype(dtype)
            np_res = numpy_apis[self.api](input)
            for place in self.places:
                # it is more convenient to use `guard` than `enable/disable_**` here
                with fluid.dygraph.guard(place):
                    input_t = paddle.to_tensor(input)
                    res = paddle_apis[self.api](input_t).numpy()
                    self.assertTrue(np.array_equal(np_res, res))
                    res_t = input_t.real().numpy(
                    ) if self.api == "real" else input_t.imag().numpy()
                    self.assertTrue(np.array_equal(np_res, res_t))

    def test_name_argument(self):
        with static.program_guard(static.Program()):
            x = static.data(name="x", shape=self._shape, dtype=self.dtypes[0])
            out = paddle_apis[self.api](x, name="real_res")
            self.assertTrue("real_res" in out.name)

    def test_dtype_error(self):
        # in static mode
        with self.assertRaises(TypeError):
            with static.program_guard(static.Program()):
                x = static.data(name="x", shape=self._shape, dtype="float32")
                out = paddle_apis[self.api](x, name="real_res")

        # in dynamic mode
        with self.assertRaises(RuntimeError):
            with fluid.dygraph.guard():
                input = np.random.random(self._shape).astype("float32")
                input_t = paddle.to_tensor(input)
                res = paddle_apis[self.api](input_t)


class TestImagAPI(TestRealAPI):
    def setUp(self):
        # switch to static
        paddle.enable_static()
        # prepare test attrs
        self.api = "imag"
        self.dtypes = ["complex64", "complex128"]
        self.places = [paddle.CPUPlace()]
        if paddle.is_compiled_with_cuda():
            self.places.append(paddle.CUDAPlace(0))
        self._shape = [2, 20, 2, 3]


if __name__ == "__main__":
    unittest.main()