|
|
|
@ -1,6 +1,7 @@
|
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
|
|
#include <algorithm>
|
|
|
|
|
#include <type_traits>
|
|
|
|
|
#include "paddle/framework/attr_checker.h"
|
|
|
|
|
#include "paddle/framework/op_desc.pb.h"
|
|
|
|
|
#include "paddle/framework/op_proto.pb.h"
|
|
|
|
@ -101,8 +102,11 @@ class OpRegistry {
|
|
|
|
|
OpProto& op_proto = protos()[op_type];
|
|
|
|
|
OpAttrChecker& op_checker = op_checkers()[op_type];
|
|
|
|
|
ProtoMakerType(&op_proto, &op_checker);
|
|
|
|
|
PADDLE_ENFORCE(op_proto.IsInitialized(),
|
|
|
|
|
"Fail to initialize %s's OpProto !", op_type);
|
|
|
|
|
*op_proto.mutable_type() = op_type;
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
op_proto.IsInitialized(),
|
|
|
|
|
"Fail to initialize %s's OpProto, because %s is not initialized",
|
|
|
|
|
op_type, op_proto.InitializationErrorString());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static OperatorBase* CreateOp(const OpDesc& op_desc) {
|
|
|
|
@ -143,18 +147,72 @@ class OpRegistry {
|
|
|
|
|
template <typename OpType, typename ProtoMakerType>
|
|
|
|
|
class OpRegisterHelper {
|
|
|
|
|
public:
|
|
|
|
|
OpRegisterHelper(std::string op_type) {
|
|
|
|
|
OpRegisterHelper(const char* op_type) {
|
|
|
|
|
OpRegistry::RegisterOp<OpType, ProtoMakerType>(op_type);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
#define REGISTER_OP(type, op_class, op_maker_class) \
|
|
|
|
|
class op_class##Register { \
|
|
|
|
|
private: \
|
|
|
|
|
const static OpRegisterHelper<op_class, op_maker_class> reg; \
|
|
|
|
|
#define STATIC_ASSERT_GLOBAL_NAMESPACE(uniq_name, msg) \
|
|
|
|
|
struct __test_global_namespace_##uniq_name##__ {}; \
|
|
|
|
|
static_assert(std::is_same<::__test_global_namespace_##uniq_name##__, \
|
|
|
|
|
__test_global_namespace_##uniq_name##__>::value, \
|
|
|
|
|
msg)
|
|
|
|
|
|
|
|
|
|
#define REGISTER_OP(__op_type, __op_class, __op_maker_class) \
|
|
|
|
|
STATIC_ASSERT_GLOBAL_NAMESPACE(__reg_op__##__op_type, \
|
|
|
|
|
"REGISTER_OP must be in global namespace"); \
|
|
|
|
|
static ::paddle::framework::OpRegisterHelper<__op_class, __op_maker_class> \
|
|
|
|
|
__op_register_##__op_type##__(#__op_type); \
|
|
|
|
|
int __op_register_##__op_type##_handle__() { return 0; }
|
|
|
|
|
|
|
|
|
|
#define REGISTER_OP_KERNEL(type, GPU_OR_CPU, PlaceType, KernelType) \
|
|
|
|
|
STATIC_ASSERT_GLOBAL_NAMESPACE( \
|
|
|
|
|
__reg_op_kernel_##type##_##GPU_OR_CPU##__, \
|
|
|
|
|
"REGISTER_OP_KERNEL must be in global namespace"); \
|
|
|
|
|
struct __op_kernel_register__##type##__ { \
|
|
|
|
|
__op_kernel_register__##type##__() { \
|
|
|
|
|
::paddle::framework::OperatorWithKernel::OpKernelKey key; \
|
|
|
|
|
key.place_ = PlaceType(); \
|
|
|
|
|
::paddle::framework::OperatorWithKernel::AllOpKernels()[#type][key] \
|
|
|
|
|
.reset(new KernelType()); \
|
|
|
|
|
} \
|
|
|
|
|
}; \
|
|
|
|
|
const OpRegisterHelper<op_class, op_maker_class> op_class##Register::reg( \
|
|
|
|
|
#type)
|
|
|
|
|
static __op_kernel_register__##type##__ __reg_kernel_##type##__; \
|
|
|
|
|
int __op_kernel_register_##type##_handle_##GPU_OR_CPU##__() { return 0; }
|
|
|
|
|
|
|
|
|
|
#define REGISTER_OP_GPU_KERNEL(type, KernelType) \
|
|
|
|
|
REGISTER_OP_KERNEL(type, GPU, ::paddle::platform::GPUPlace, KernelType)
|
|
|
|
|
|
|
|
|
|
#define REGISTER_OP_CPU_KERNEL(type, KernelType) \
|
|
|
|
|
REGISTER_OP_KERNEL(type, CPU, ::paddle::platform::CPUPlace, KernelType)
|
|
|
|
|
|
|
|
|
|
#define USE_OP_WITHOUT_KERNEL(op_type) \
|
|
|
|
|
STATIC_ASSERT_GLOBAL_NAMESPACE( \
|
|
|
|
|
__use_op_without_kernel_##op_type, \
|
|
|
|
|
"USE_OP_WITHOUT_KERNEL must be in global namespace"); \
|
|
|
|
|
extern int __op_register_##op_type##_handle__(); \
|
|
|
|
|
static int __use_op_ptr_##op_type##_without_kernel__ \
|
|
|
|
|
__attribute__((unused)) = __op_register_##op_type##_handle__()
|
|
|
|
|
|
|
|
|
|
#define USE_OP_KERNEL(op_type, CPU_OR_GPU) \
|
|
|
|
|
STATIC_ASSERT_GLOBAL_NAMESPACE(__use_op_kernel_##op_type##_##CPU_OR_GPU##__, \
|
|
|
|
|
"USE_OP_KERNEL must be in global namespace"); \
|
|
|
|
|
extern int __op_kernel_register_##op_type##_handle_##CPU_OR_GPU##__(); \
|
|
|
|
|
static int __use_op_ptr_##op_type##_##CPU_OR_GPU##_kernel__ \
|
|
|
|
|
__attribute__((unused)) = \
|
|
|
|
|
__op_kernel_register_##op_type##_handle_##CPU_OR_GPU##__()
|
|
|
|
|
|
|
|
|
|
#ifdef PADDLE_ONLY_CPU
|
|
|
|
|
#define USE_OP(op_type) \
|
|
|
|
|
USE_OP_WITHOUT_KERNEL(op_type); \
|
|
|
|
|
USE_OP_KERNEL(op_type, CPU);
|
|
|
|
|
|
|
|
|
|
#else
|
|
|
|
|
#define USE_OP(op_type) \
|
|
|
|
|
USE_OP_WITHOUT_KERNEL(op_type); \
|
|
|
|
|
USE_OP_KERNEL(op_type, CPU); \
|
|
|
|
|
USE_OP_KERNEL(op_type, GPU)
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
} // namespace framework
|
|
|
|
|
} // namespace paddle
|
|
|
|
|