add complex64 and complex128 types; add +-*/@ and slice operator for complex types (#29199)

* add complex64 and complex128 types; add +-*/@ and slice operator for complex types

* add test cases for complex elementwise, matmul and getitem unit tests

* add test cases for complex types

* add test cases for complex matmul unit test
release/2.0-rc1
chentianyu03 authored 5 years ago, committed by GitHub
parent cc9c619679
commit 8f45d14263
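
The two new platform types mirror std::complex: they are constructed from a (real, imag) pair, expose public real/imag members, and support +, -, *, and /, which is what the kernels below rely on. A minimal standalone sketch of the semantics, using std::complex<float> as a stand-in for paddle::platform::complex64:

#include <complex>
#include <iostream>

int main() {
  std::complex<float> a(1.0f, 2.0f), b(3.0f, -1.0f);
  std::complex<float> sum = a + b;   // (4, 1)
  std::complex<float> prod = a * b;  // (1*3 - 2*(-1)) + (1*(-1) + 2*3)i = (5, 5)
  std::cout << sum << " " << prod << "\n";  // prints (4,1) (5,5)
  return 0;
}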

@@ -18,6 +18,8 @@ limitations under the License. */
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
@@ -25,6 +27,8 @@ namespace paddle {
namespace platform {
struct bfloat16;
struct float16;
struct complex64;
struct complex128;
} // namespace platform
} // namespace paddle
@@ -45,23 +49,27 @@ struct DataTypeTrait<void> {
#define _ForEachDataTypeHelper_(callback, cpp_type, proto_type) \
callback(cpp_type, ::paddle::framework::proto::VarType::proto_type);
#define _ForEachDataType_(callback) \
_ForEachDataTypeHelper_(callback, float, FP32); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::bfloat16, BF16); \
_ForEachDataTypeHelper_(callback, double, FP64); \
_ForEachDataTypeHelper_(callback, int, INT32); \
_ForEachDataTypeHelper_(callback, int64_t, INT64); \
_ForEachDataTypeHelper_(callback, bool, BOOL); \
_ForEachDataTypeHelper_(callback, uint8_t, UINT8); \
_ForEachDataTypeHelper_(callback, int16_t, INT16); \
_ForEachDataTypeHelper_(callback, int8_t, INT8)
#define _ForEachDataTypeSmall_(callback) \
_ForEachDataTypeHelper_(callback, float, FP32); \
_ForEachDataTypeHelper_(callback, double, FP64); \
_ForEachDataTypeHelper_(callback, int, INT32); \
_ForEachDataTypeHelper_(callback, int64_t, INT64);
#define _ForEachDataType_(callback) \
_ForEachDataTypeHelper_(callback, float, FP32); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::float16, FP16); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::bfloat16, BF16); \
_ForEachDataTypeHelper_(callback, double, FP64); \
_ForEachDataTypeHelper_(callback, int, INT32); \
_ForEachDataTypeHelper_(callback, int64_t, INT64); \
_ForEachDataTypeHelper_(callback, bool, BOOL); \
_ForEachDataTypeHelper_(callback, uint8_t, UINT8); \
_ForEachDataTypeHelper_(callback, int16_t, INT16); \
_ForEachDataTypeHelper_(callback, int8_t, INT8); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::complex64, COMPLEX64); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::complex128, COMPLEX128);
#define _ForEachDataTypeSmall_(callback) \
_ForEachDataTypeHelper_(callback, float, FP32); \
_ForEachDataTypeHelper_(callback, double, FP64); \
_ForEachDataTypeHelper_(callback, int, INT32); \
_ForEachDataTypeHelper_(callback, int64_t, INT64); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::complex64, COMPLEX64); \
_ForEachDataTypeHelper_(callback, ::paddle::platform::complex128, COMPLEX128);
// For the use of thrust, as index-type elements can be only integers.
#define _ForEachDataTypeTiny_(callback) \

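_ForEachDataType_ is an X-macro: the caller passes a callback(cpp_type, proto_type) macro, and the list applies it once per supported type, so appending complex64 and complex128 here automatically extends every type switch built on top of it. A toy standalone sketch of the pattern (hypothetical two-entry type list):

#include <cstdio>

#define FOR_EACH_TYPE(callback) \
  callback(float, FP32);        \
  callback(double, FP64)

#define PRINT_TYPE(cpp_type, proto_type) \
  std::printf("%s -> %s\n", #cpp_type, #proto_type)

int main() {
  FOR_EACH_TYPE(PRINT_TYPE);  // expands to one printf per listed type
  return 0;
}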
@@ -169,6 +169,10 @@ static void PrintNanInf(const T* value, const size_t numel, int print_num,
#pragma omp declare reduction(+ : paddle::platform::float16 : omp_out += omp_in)
#pragma omp declare reduction(+ : paddle::platform::bfloat16 : omp_out += \
omp_in)
#pragma omp declare reduction(+ : paddle::platform::complex64 : omp_out += \
omp_in)
#pragma omp declare reduction(+ : paddle::platform::complex128 : omp_out += \
omp_in)
#endif
template <typename T>
@@ -222,6 +226,37 @@ void CheckNanInf<paddle::platform::bfloat16>(
PrintNanInf(value, numel, print_num, op_type, var_name);
}
}
template <>
void CheckNanInf<paddle::platform::complex64>(
const paddle::platform::complex64* value, const size_t numel, int print_num,
const std::string& op_type, const std::string& var_name) {
paddle::platform::complex64 sum(0.0, 0.0);
#pragma omp parallel for reduction(+ : sum)
for (size_t i = 0; i < numel; ++i) {
sum += (value[i] - value[i]);
}
if (std::isnan(sum) || std::isinf(sum)) {
PrintNanInf(value, numel, print_num, op_type, var_name);
}
}
template <>
void CheckNanInf<paddle::platform::complex128>(
const paddle::platform::complex128* value, const size_t numel,
int print_num, const std::string& op_type, const std::string& var_name) {
paddle::platform::complex128 sum(0.0, 0.0);
#pragma omp parallel for reduction(+ : sum)
for (size_t i = 0; i < numel; ++i) {
sum += (value[i] - value[i]);
}
if (std::isnan(sum) || std::isinf(sum)) {
PrintNanInf(value, numel, print_num, op_type, var_name);
}
}
#endif
template <>

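Two details in the hunk above deserve a note. The #pragma omp declare reduction lines teach OpenMP how to combine the custom complex types under +, since they are not built-in arithmetic types. The check itself relies on the identity that x - x is zero for finite x but NaN when x is NaN or Inf, so the reduced sum is non-finite exactly when some element is. A self-contained sketch of both ideas, with std::complex<float> standing in for complex64:

#include <cmath>
#include <complex>
#include <cstddef>
#include <cstdio>

#if defined(_OPENMP)
#pragma omp declare reduction(+ : std::complex<float> : omp_out += omp_in)
#endif

bool HasNanInf(const std::complex<float>* value, std::size_t numel) {
  std::complex<float> sum(0.0f, 0.0f);
#if defined(_OPENMP)
#pragma omp parallel for reduction(+ : sum)
#endif
  for (std::size_t i = 0; i < numel; ++i) {
    sum += value[i] - value[i];  // 0 if finite, NaN otherwise
  }
  return std::isnan(sum.real()) || std::isnan(sum.imag());
}

int main() {
  std::complex<float> ok[2] = {{1.f, 2.f}, {3.f, 4.f}};
  std::complex<float> bad[2] = {{1.f, 2.f}, {NAN, 0.f}};
  std::printf("%d %d\n", HasNanInf(ok, 2), HasNanInf(bad, 2));  // prints: 0 1
  return 0;
}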
@@ -116,6 +116,8 @@ message VarType {
UINT8 = 20;
INT8 = 21;
BF16 = 22;
COMPLEX64 = 23;
COMPLEX128 = 24;
// Other types that may need additional descriptions
LOD_TENSOR = 7;

@@ -22,6 +22,8 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle {
@@ -990,6 +992,40 @@ std::ostream& print_tensor(std::ostream& os, const framework::Tensor& tensor) {
return os;
}
template <>
std::ostream& print_tensor<paddle::platform::complex64>(
std::ostream& os, const framework::Tensor& tensor) {
auto inspect = tensor.data<paddle::platform::complex64>();
auto element_num = tensor.numel();
os << " - data: [";
if (element_num > 0) {
os << signed(inspect[0].real) << "+" << signed(inspect[0].imag) << "j";
for (int j = 1; j < element_num; ++j) {
os << " " << signed(inspect[j].real) << "+" << signed(inspect[j].imag) << "j";
}
}
os << "]";
return os;
}
template <>
std::ostream& print_tensor<paddle::platform::complex128>(
std::ostream& os, const framework::Tensor& tensor) {
auto inspect = tensor.data<paddle::platform::complex128>();
auto element_num = tensor.numel();
os << " - data: [";
if (element_num > 0) {
os << signed(inspect[0].real) << "+" << signed(inspect[0].imag) << "j";
for (int j = 1; j < element_num; ++j) {
os << " " << signed(inspect[j].real) << "+" << signed(inspect[j].imag) << "j";
}
}
os << "]";
return os;
}
std::ostream& operator<<(std::ostream& os, const Tensor& t) {
os << " - place: " << t.place() << "\n";
os << " - shape: [" << t.dims() << "]\n";

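Given the a+bj formatting above, a complex tensor renders its elements on one line, space-separated; note that signed(...) truncates the float components to integers for display. A standalone rendition of the loop (toy struct in place of platform::complex64):

#include <iostream>

struct C64 { float real, imag; };  // toy stand-in for platform::complex64

int main() {
  C64 data[3] = {{1, 2}, {3, -4}, {5, 0}};
  std::cout << "  - data: [";
  std::cout << signed(data[0].real) << "+" << signed(data[0].imag) << "j";
  for (int j = 1; j < 3; ++j) {
    std::cout << " " << signed(data[j].real) << "+" << signed(data[j].imag) << "j";
  }
  std::cout << "]\n";  // prints:   - data: [1+2j 3+-4j 5+0j]
  return 0;
}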
@@ -19,6 +19,8 @@ limitations under the License. */
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
namespace paddle {
namespace framework {
@@ -128,13 +130,21 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int64_t>);
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CPU_KERNEL(
elementwise_add_grad,
ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::ElementwiseAddGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CPU_KERNEL(
elementwise_add_grad_grad,
ops::ElementwiseAddDoubleGradKernel<paddle::platform::CPUDeviceContext,
@@ -144,7 +154,11 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseAddDoubleGradKernel<paddle::platform::CPUDeviceContext,
int>,
ops::ElementwiseAddDoubleGradKernel<paddle::platform::CPUDeviceContext,
int64_t>);
int64_t>,
ops::ElementwiseAddDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::ElementwiseAddDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
// A specialization elementwise_add operator, used in gradient accumulation with
// inplace addto.
@@ -159,4 +173,8 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int64_t>);
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::ElementwiseAddKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);

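The registration pattern above repeats for every operator touched by this commit: each REGISTER_OP_*_KERNEL call lists one kernel instantiation per supported dtype, and complex support amounts to appending two more entries. Conceptually the registry maps a dtype key to a kernel functor; a toy sketch of that dispatch idea (hypothetical names, not Paddle's actual registry):

#include <functional>
#include <iostream>
#include <map>
#include <typeindex>
#include <typeinfo>

std::map<std::type_index, std::function<void()>> g_add_kernels;

template <typename T>
void RegisterAddKernel() {
  g_add_kernels[std::type_index(typeid(T))] = [] {
    std::cout << "add kernel for " << typeid(T).name() << "\n";
  };
}

int main() {
  RegisterAddKernel<float>();
  RegisterAddKernel<double>();
  // This commit effectively appends two more such entries per op, e.g.
  // RegisterAddKernel<paddle::platform::complex64>();
  g_add_kernels.at(std::type_index(typeid(double)))();  // dispatch by dtype
  return 0;
}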
@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
@@ -95,26 +97,35 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseAddKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::float16>);
ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex64>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex128>);
REGISTER_OP_CUDA_KERNEL(
elementwise_add_grad,
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, float>,
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, plat::float16>);
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, plat::complex64>,
ops::ElementwiseAddGradKernel<plat::CUDADeviceContext, plat::complex128>);
REGISTER_OP_CUDA_KERNEL(
elementwise_add_grad_grad,
ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, float>,
ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext,
plat::float16>);
plat::complex64>,
ops::ElementwiseAddDoubleGradKernel<plat::CUDADeviceContext,
plat::complex128>);
REGISTER_OP_CUDA_KERNEL(
grad_add, ops::ElementwiseAddKernel<plat::CUDADeviceContext, float>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::float16>);
ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex64>,
ops::ElementwiseAddKernel<plat::CUDADeviceContext, plat::complex128>);

@@ -16,6 +16,8 @@ limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
namespace paddle {
namespace operators {
@@ -130,13 +132,21 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext, int64_t>);
ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::ElementwiseDivKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CPU_KERNEL(
elementwise_div_grad,
ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::ElementwiseDivGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CPU_KERNEL(
elementwise_div_grad_grad,
@@ -147,4 +157,8 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CPUDeviceContext,
int>,
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CPUDeviceContext,
int64_t>);
int64_t>,
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);

@@ -14,6 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_div_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
@@ -102,7 +104,11 @@ REGISTER_OP_CUDA_KERNEL(
paddle::platform::float16>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int64_t>);
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
ops::ElementwiseDivKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CUDA_KERNEL(
elementwise_div_grad,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, float>,
@@ -110,8 +116,11 @@ REGISTER_OP_CUDA_KERNEL(
paddle::platform::float16>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
ops::ElementwiseDivGradKernel<paddle::platform::CUDADeviceContext,
int64_t>);
paddle::platform::complex128>);
REGISTER_OP_CUDA_KERNEL(
elementwise_div_grad_grad,
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
@@ -123,4 +132,8 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
int>,
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
int64_t>);
int64_t>,
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
ops::ElementwiseDivDoubleGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);

@@ -16,6 +16,8 @@ limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
namespace paddle {
namespace operators {
@@ -130,13 +132,21 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, int64_t>);
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::ElementwiseMulKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CPU_KERNEL(
elementwise_mul_grad,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::ElementwiseMulGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CPU_KERNEL(
elementwise_mul_grad_grad,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
@@ -146,4 +156,8 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
int>,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
int64_t>);
int64_t>,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::ElementwiseMulDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);

@@ -14,6 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_mul_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
@@ -100,19 +102,26 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseMulKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::float16>);
ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::complex64>,
ops::ElementwiseMulKernel<plat::CUDADeviceContext, plat::complex128>);
REGISTER_OP_CUDA_KERNEL(
elementwise_mul_grad,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, float>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, plat::float16>);
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, plat::complex64>,
ops::ElementwiseMulGradKernel<plat::CUDADeviceContext, plat::complex128>);
REGISTER_OP_CUDA_KERNEL(
elementwise_mul_grad_grad,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, float>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, double>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, int>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, int64_t>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext, plat::float16>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext,
plat::float16>);
plat::complex64>,
ops::ElementwiseMulDoubleGradKernel<plat::CUDADeviceContext,
plat::complex128>);

@@ -17,6 +17,8 @@ limitations under the License. */
#include <string>
#include "paddle/fluid/operators/elementwise/elementwise_op.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
namespace paddle {
namespace framework {
@@ -125,13 +127,21 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, int64_t>);
ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::ElementwiseSubKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CPU_KERNEL(
elementwise_sub_grad,
ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::ElementwiseSubGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CPU_KERNEL(
elementwise_sub_grad_grad,
ops::ElementwiseSubDoubleGradKernel<paddle::platform::CPUDeviceContext,
@@ -141,4 +151,8 @@ REGISTER_OP_CPU_KERNEL(
ops::ElementwiseSubDoubleGradKernel<paddle::platform::CPUDeviceContext,
int>,
ops::ElementwiseSubDoubleGradKernel<paddle::platform::CPUDeviceContext,
int64_t>);
int64_t>,
ops::ElementwiseSubDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::ElementwiseSubDoubleGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);

@@ -14,6 +14,8 @@ limitations under the License. */
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/elementwise/elementwise_sub_op.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
@@ -99,7 +101,11 @@ REGISTER_OP_CUDA_KERNEL(
paddle::platform::float16>,
ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext, int64_t>);
ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
ops::ElementwiseSubKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CUDA_KERNEL(
elementwise_sub_grad,
ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext, float>,
@@ -107,8 +113,11 @@ REGISTER_OP_CUDA_KERNEL(
paddle::platform::float16>,
ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
ops::ElementwiseSubGradKernel<paddle::platform::CUDADeviceContext,
int64_t>);
paddle::platform::complex128>);
REGISTER_OP_CUDA_KERNEL(
elementwise_sub_grad_grad,
ops::ElementwiseSubDoubleGradKernel<paddle::platform::CUDADeviceContext,
@@ -118,4 +127,8 @@ REGISTER_OP_CUDA_KERNEL(
ops::ElementwiseSubDoubleGradKernel<paddle::platform::CUDADeviceContext,
int>,
ops::ElementwiseSubDoubleGradKernel<paddle::platform::CUDADeviceContext,
int64_t>);
int64_t>,
ops::ElementwiseSubDoubleGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
ops::ElementwiseSubDoubleGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);

File diff suppressed because it is too large

File diff suppressed because it is too large

@@ -65,14 +65,16 @@ class SplitFunctor {
} // namespace operators
} // namespace paddle
#define FOR_ALL_TYPES(macro) \
macro(int); \
macro(float); \
macro(double); \
macro(bool); \
macro(int64_t); \
macro(int16_t); \
macro(uint8_t); \
macro(int8_t); \
macro(::paddle::platform::float16); \
macro(::paddle::platform::bfloat16)
#define FOR_ALL_TYPES(macro) \
macro(int); \
macro(float); \
macro(double); \
macro(bool); \
macro(int64_t); \
macro(int16_t); \
macro(uint8_t); \
macro(int8_t); \
macro(::paddle::platform::float16); \
macro(::paddle::platform::bfloat16); \
macro(::paddle::platform::complex64); \
macro(::paddle::platform::complex128)

@@ -44,6 +44,8 @@ template struct SetConstant<platform::CPUDeviceContext, int>;
template struct SetConstant<platform::CPUDeviceContext, int64_t>;
template struct SetConstant<platform::CPUDeviceContext, bool>;
template struct SetConstant<platform::CPUDeviceContext, uint8_t>;
template struct SetConstant<platform::CPUDeviceContext, platform::complex64>;
template struct SetConstant<platform::CPUDeviceContext, platform::complex128>;
#ifdef PADDLE_WITH_XPU
template struct SetConstant<platform::XPUDeviceContext, platform::float16>;
@@ -54,19 +56,23 @@ template struct SetConstant<platform::XPUDeviceContext, int64_t>;
template struct SetConstant<platform::XPUDeviceContext, bool>;
#endif
#define DEFINE_CPU_TRANS(RANK) \
template struct Transpose<platform::CPUDeviceContext, platform::float16, \
RANK>; \
template struct Transpose<platform::CPUDeviceContext, platform::bfloat16, \
RANK>; \
template struct Transpose<platform::CPUDeviceContext, float, RANK>; \
template struct Transpose<platform::CPUDeviceContext, double, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int64_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, bool, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int16_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, uint8_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int8_t, RANK>;
#define DEFINE_CPU_TRANS(RANK) \
template struct Transpose<platform::CPUDeviceContext, platform::float16, \
RANK>; \
template struct Transpose<platform::CPUDeviceContext, platform::bfloat16, \
RANK>; \
template struct Transpose<platform::CPUDeviceContext, float, RANK>; \
template struct Transpose<platform::CPUDeviceContext, double, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int64_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, bool, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int16_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, uint8_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, int8_t, RANK>; \
template struct Transpose<platform::CPUDeviceContext, platform::complex64, \
RANK>; \
template struct Transpose<platform::CPUDeviceContext, platform::complex128, \
RANK>;
DEFINE_CPU_TRANS(1);
DEFINE_CPU_TRANS(2);
@@ -117,6 +123,8 @@ DEFINE_CPU_TRANS_NORMAL(bool);
DEFINE_CPU_TRANS_NORMAL(int16_t);
DEFINE_CPU_TRANS_NORMAL(uint8_t);
DEFINE_CPU_TRANS_NORMAL(int8_t);
DEFINE_CPU_TRANS_NORMAL(platform::complex64);
DEFINE_CPU_TRANS_NORMAL(platform::complex128);
struct TensorSetConstantCPU {
TensorSetConstantCPU(framework::Tensor* tensor, float value)

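The template struct SetConstant<...>; lines and the DEFINE_CPU_TRANS block above are explicit template instantiations: the definitions live in math_function_impl.h, and each element type must be instantiated in this translation unit for its symbols to exist at link time, hence exactly one new line per complex type. A minimal sketch of the mechanism (toy template, not Paddle's):

#include <iostream>

template <typename T>
struct Fill {
  void operator()(T* data, int n, T value) const {
    for (int i = 0; i < n; ++i) data[i] = value;
  }
};

// Explicit instantiation: emits Fill<float>'s definition into this
// translation unit so other files can link against it.
template struct Fill<float>;

int main() {
  float buf[4];
  Fill<float>()(buf, 4, 1.5f);
  std::cout << buf[0] << "\n";  // 1.5
  return 0;
}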
@@ -19,6 +19,8 @@ limitations under the License. */
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/math_function_impl.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
@@ -27,6 +29,8 @@ namespace math {
using float16 = paddle::platform::float16;
using bfloat16 = paddle::platform::bfloat16;
using complex64 = paddle::platform::complex64;
using complex128 = paddle::platform::complex128;
template struct SetConstant<platform::CUDADeviceContext, platform::float16>;
template struct SetConstant<platform::CUDADeviceContext, float>;
@@ -34,15 +38,19 @@ template struct SetConstant<platform::CUDADeviceContext, double>;
template struct SetConstant<platform::CUDADeviceContext, int>;
template struct SetConstant<platform::CUDADeviceContext, int64_t>;
template struct SetConstant<platform::CUDADeviceContext, bool>;
template struct SetConstant<platform::CUDADeviceContext, platform::complex64>;
template struct SetConstant<platform::CUDADeviceContext, platform::complex128>;
#define DEFINE_GPU_TRANS(RANK) \
template struct Transpose<platform::CUDADeviceContext, float, RANK>; \
template struct Transpose<platform::CUDADeviceContext, double, RANK>; \
template struct Transpose<platform::CUDADeviceContext, float16, RANK>; \
template struct Transpose<platform::CUDADeviceContext, bfloat16, RANK>; \
template struct Transpose<platform::CUDADeviceContext, int8_t, RANK>; \
template struct Transpose<platform::CUDADeviceContext, int32_t, RANK>; \
template struct Transpose<platform::CUDADeviceContext, int64_t, RANK>;
#define DEFINE_GPU_TRANS(RANK) \
template struct Transpose<platform::CUDADeviceContext, float, RANK>; \
template struct Transpose<platform::CUDADeviceContext, double, RANK>; \
template struct Transpose<platform::CUDADeviceContext, float16, RANK>; \
template struct Transpose<platform::CUDADeviceContext, bfloat16, RANK>; \
template struct Transpose<platform::CUDADeviceContext, int8_t, RANK>; \
template struct Transpose<platform::CUDADeviceContext, int32_t, RANK>; \
template struct Transpose<platform::CUDADeviceContext, int64_t, RANK>; \
template struct Transpose<platform::CUDADeviceContext, complex64, RANK>; \
template struct Transpose<platform::CUDADeviceContext, complex128, RANK>;
DEFINE_GPU_TRANS(1);
DEFINE_GPU_TRANS(2);
@@ -132,6 +140,8 @@ DEFINE_GPU_TRANS_NORMAL(bool);
DEFINE_GPU_TRANS_NORMAL(int16_t);
DEFINE_GPU_TRANS_NORMAL(uint8_t);
DEFINE_GPU_TRANS_NORMAL(int8_t);
DEFINE_GPU_TRANS_NORMAL(complex64);
DEFINE_GPU_TRANS_NORMAL(complex128);
struct TensorSetConstantGPU {
TensorSetConstantGPU(const platform::DeviceContext& context,

@@ -168,9 +168,17 @@ REGISTER_OPERATOR(matmul_v2_grad, ops::MatMulV2OpGrad);
REGISTER_OP_CPU_KERNEL(
matmul_v2, ops::MatMulV2Kernel<paddle::platform::CPUDeviceContext, float>,
ops::MatMulV2Kernel<paddle::platform::CPUDeviceContext, double>);
ops::MatMulV2Kernel<paddle::platform::CPUDeviceContext, double>,
ops::MatMulV2Kernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::MatMulV2Kernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CPU_KERNEL(
matmul_v2_grad,
ops::MatMulV2GradKernel<paddle::platform::CPUDeviceContext, float>,
ops::MatMulV2GradKernel<paddle::platform::CPUDeviceContext, double>);
ops::MatMulV2GradKernel<paddle::platform::CPUDeviceContext, double>,
ops::MatMulV2GradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::MatMulV2GradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);

@@ -20,9 +20,13 @@ namespace plf = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
matmul_v2, ops::MatMulV2Kernel<plf::CUDADeviceContext, float>,
ops::MatMulV2Kernel<plf::CUDADeviceContext, double>,
ops::MatMulV2Kernel<plf::CUDADeviceContext, plf::float16>);
ops::MatMulV2Kernel<plf::CUDADeviceContext, plf::float16>,
ops::MatMulV2Kernel<plf::CUDADeviceContext, plf::complex64>,
ops::MatMulV2Kernel<plf::CUDADeviceContext, plf::complex128>);
REGISTER_OP_CUDA_KERNEL(
matmul_v2_grad, ops::MatMulV2GradKernel<plf::CUDADeviceContext, float>,
ops::MatMulV2GradKernel<plf::CUDADeviceContext, double>,
ops::MatMulV2GradKernel<plf::CUDADeviceContext, plf::float16>);
ops::MatMulV2GradKernel<plf::CUDADeviceContext, plf::float16>,
ops::MatMulV2GradKernel<plf::CUDADeviceContext, plf::complex64>,
ops::MatMulV2GradKernel<plf::CUDADeviceContext, plf::complex128>);

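The MatMulV2Kernel and MatMulV2GradKernel templates instantiated above are defined in a header not shown in this view; for reference, the new complex instantiations compute the textbook product C[i][j] = sum_k A[i][k] * B[k][j] over complex operands. A standalone sketch with std::complex (not Paddle code):

#include <complex>
#include <iostream>

using c64 = std::complex<float>;

int main() {
  const c64 A[2][2] = {{{1, 1}, {0, 2}}, {{3, 0}, {1, -1}}};
  const c64 B[2][2] = {{{2, 0}, {0, 1}}, {{1, 1}, {4, 0}}};
  c64 C[2][2] = {};
  for (int i = 0; i < 2; ++i)
    for (int j = 0; j < 2; ++j)
      for (int k = 0; k < 2; ++k) C[i][j] += A[i][k] * B[k][j];
  std::cout << C[0][0] << " " << C[0][1] << "\n"
            << C[1][0] << " " << C[1][1] << "\n";
  return 0;
}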
@@ -424,10 +424,18 @@ REGISTER_OP_CPU_KERNEL(
slice, ops::SliceKernel<paddle::platform::CPUDeviceContext, int>,
ops::SliceKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::SliceKernel<paddle::platform::CPUDeviceContext, float>,
ops::SliceKernel<paddle::platform::CPUDeviceContext, double>);
ops::SliceKernel<paddle::platform::CPUDeviceContext, double>,
ops::SliceKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::SliceKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CPU_KERNEL(
slice_grad, ops::SliceGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::SliceGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::SliceGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::SliceGradKernel<paddle::platform::CPUDeviceContext, double>);
ops::SliceGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::SliceGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::SliceGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);

@@ -23,7 +23,9 @@ REGISTER_OP_CUDA_KERNEL(
ops::SliceKernel<paddle::platform::CUDADeviceContext, double>,
ops::SliceKernel<paddle::platform::CUDADeviceContext, int>,
ops::SliceKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::SliceKernel<paddle::platform::CUDADeviceContext, plat::float16>);
ops::SliceKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::SliceKernel<paddle::platform::CUDADeviceContext, plat::complex64>,
ops::SliceKernel<paddle::platform::CUDADeviceContext, plat::complex128>);
REGISTER_OP_CUDA_KERNEL(
slice_grad,
@@ -31,4 +33,7 @@ REGISTER_OP_CUDA_KERNEL(
ops::SliceGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::SliceGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::SliceGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::SliceGradKernel<paddle::platform::CUDADeviceContext, plat::float16>);
ops::SliceGradKernel<paddle::platform::CUDADeviceContext, plat::float16>,
ops::SliceGradKernel<paddle::platform::CUDADeviceContext, plat::complex64>,
ops::SliceGradKernel<paddle::platform::CUDADeviceContext,
plat::complex128>);

@@ -327,11 +327,19 @@ REGISTER_OP_CPU_KERNEL(
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, int>,
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, float>,
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, double>);
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext, double>,
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::StridedSliceKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CPU_KERNEL(
strided_slice_grad,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, int>,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, double>);
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex64>,
ops::StridedSliceGradKernel<paddle::platform::CPUDeviceContext,
paddle::platform::complex128>);

@@ -13,6 +13,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/strided_slice_op.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
@@ -20,11 +22,19 @@ REGISTER_OP_CUDA_KERNEL(
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext, int>,
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext, float>,
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext, double>);
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext, double>,
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
ops::StridedSliceKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);
REGISTER_OP_CUDA_KERNEL(
strided_slice_grad,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext, int>,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext, int64_t>,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext, double>);
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex64>,
ops::StridedSliceGradKernel<paddle::platform::CUDADeviceContext,
paddle::platform::complex128>);

File diff suppressed because it is too large

File diff suppressed because it is too large

Some files were not shown because too many files have changed in this diff
