@@ -17,6 +17,7 @@ limitations under the License. */
 #include <algorithm>
 #include <atomic>
 #include <type_traits>
+#include <typeinfo>
 #include <unordered_map>
 #include <unordered_set>
 #include "paddle/framework/attribute.h"
@@ -174,6 +175,15 @@ Add a mark to which output is temporary is helpful for future optimization.
   bool has_temporary_output_{false};
 };

+class NOPMaker : public OpProtoAndCheckerMaker {};
+
+struct OpInfo {
+  std::function<OperatorBase*()> creator_;
+  std::string grad_op_type_;
+  OpProto* proto_;
+  OpAttrChecker* checker_;
+};
+
 class OpRegistry {
   using OpCreator = std::function<OperatorBase*()>;
   using VarIndexMap = std::unordered_map<std::string, int>;
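For orientation, a hedged sketch of how a registered operator's OpInfo record might be consumed (the op name "my_add" and this call site are hypothetical, not part of the diff):

    // One OpInfo per registered op: creation factory, gradient-op name, proto and checker.
    const auto& info = paddle::framework::OpRegistry::op_info_map().at("my_add");
    paddle::framework::OperatorBase* raw_op = info.creator_();  // factory stored at registration
    const std::string& grad_type = info.grad_op_type_;          // empty when no gradient op was declared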
@@ -181,18 +191,25 @@ class OpRegistry {
  public:
   template <typename OpType, typename ProtoMakerType>
-  static void RegisterOp(const std::string& op_type) {
-    op_creators()[op_type] = [] { return new OpType; };
-    OpAttrChecker& op_checker = op_checkers()[op_type];
-    OpProto& op_proto = protos()[op_type];
-    auto maker = ProtoMakerType(&op_proto, &op_checker);
+  static void RegisterOp(const std::string& op_type,
+                         const std::string& grad_op_type) {
+    PADDLE_ENFORCE(op_info_map().count(op_type) == 0,
+                   "'%s' is registered more than once.", op_type);
+    OpInfo op_info;
+    op_info.creator_ = [] { return new OpType; };
+    op_info.grad_op_type_ = grad_op_type;
+    if (std::type_index(typeid(ProtoMakerType)) !=
+        std::type_index(typeid(NOPMaker))) {
+      op_info.proto_ = new OpProto;
+      op_info.checker_ = new OpAttrChecker;
+      auto maker = ProtoMakerType(op_info.proto_, op_info.checker_);
       maker.Validate();
-      *op_proto.mutable_type() = op_type;
+      *op_info.proto_->mutable_type() = op_type;
       PADDLE_ENFORCE(
-          op_proto.IsInitialized(),
+          op_info.proto_->IsInitialized(),
           "Fail to initialize %s's OpProto, because %s is not initialized",
-          op_type, op_proto.InitializationErrorString());
+          op_type, op_info.proto_->InitializationErrorString());

       //======will be refactored in following PRs============//
       VarIndexMaps()[op_type].reset(new VarIndexMap());
       auto& varmap = *VarIndexMaps()[op_type];
       int idx = 0;
@@ -203,30 +220,26 @@ class OpRegistry {
       for (auto& var : op_proto.outputs()) {
         varmap[var.name()] = idx++;
       }
       //================================================//
     }

-  template <typename GradOpType>
-  static void RegisterGradOp(const std::string& op_type,
-                             const std::string& grad_op_type) {
-    op_creators()[grad_op_type] = [] { return new GradOpType; };
-    grad_ops()[op_type] = grad_op_type;
+    op_info_map().insert(std::make_pair(op_type, op_info));
   }

   static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
                                                 const VarNameList& inputs,
                                                 const VarNameList& outputs,
                                                 const AttributeMap& attrs) {
-    auto op_create_it = op_creators().find(type);
-    PADDLE_ENFORCE(op_create_it != op_creators().end(),
-                   "Operator %s cannot be found.", type);
+    auto it = op_info_map().find(type);
+    PADDLE_ENFORCE(it != op_info_map().end(), "'%s' has not been registered.",
+                   type);

-    auto op = op_create_it->second();
+    auto op = it->second.creator_();
     op->type_ = type;
     op->inputs_ = inputs;
     op->outputs_ = outputs;

     op->attrs_ = attrs;
-    op_checkers().at(type).Check(op->attrs_);
+    it->second.checker_->Check(op->attrs_);

     GenerateTempVariableName(op);
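A caller-side sketch of the new creation path, assuming VarNameList is a list of variable names and AttributeMap comes from attribute.h (the op name and arguments below are hypothetical):

    // A single lookup in op_info_map() replaces the old op_creators()/op_checkers() pair.
    auto my_op = paddle::framework::OpRegistry::CreateOp(
        "my_add", /*inputs=*/{"X", "Y"}, /*outputs=*/{"Out"},
        paddle::framework::AttributeMap{});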
@@ -268,14 +281,9 @@ class OpRegistry {
     return grad_op;
   }

-  static std::unordered_map<std::string, OpProto>& protos() {
-    static std::unordered_map<std::string, OpProto> protos_;
-    return protos_;
-  }
-
-  static std::unordered_map<std::string, std::string>& grad_ops() {
-    static std::unordered_map<std::string, std::string> grad_ops_;
-    return grad_ops_;
+  static std::unordered_map<std::string, const OpInfo>& op_info_map() {
+    static std::unordered_map<std::string, const OpInfo> op_info_map_;
+    return op_info_map_;
   }

   static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>>&
@@ -284,17 +292,7 @@ class OpRegistry {
     return maps_;
   }

-  static std::unordered_map<std::string, OpCreator>& op_creators() {
-    static std::unordered_map<std::string, OpCreator> op_creators_;
-    return op_creators_;
-  }
-
  private:
-  static std::unordered_map<std::string, OpAttrChecker>& op_checkers() {
-    static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
-    return op_checkers_;
-  }
-
   static void GenerateTempVariableName(OperatorBase* op) {
     static std::atomic<size_t> gUniqId(0UL);
     for (auto& outname : op->outputs_) {
@@ -323,16 +321,9 @@ class Registrar {
 template <typename OpType, typename ProtoMakerType>
 class OpRegistrar : public Registrar {
  public:
-  explicit OpRegistrar(const char* op_type) {
-    OpRegistry::RegisterOp<OpType, ProtoMakerType>(op_type);
-  }
-};
-
-template <typename GradOpType>
-class GradOpRegistrar : public Registrar {
- public:
-  GradOpRegistrar(const char* op_type, const char* grad_op_type) {
-    OpRegistry::RegisterGradOp<GradOpType>(op_type, grad_op_type);
+  OpRegistrar(const char* op_type) { OpRegistrar(op_type, ""); }
+  OpRegistrar(const char* op_type, const char* grad_op_type) {
+    OpRegistry::RegisterOp<OpType, ProtoMakerType>(op_type, grad_op_type);
   }
 };

@@ -358,30 +349,21 @@ class OpKernelRegistrar : public Registrar {
 /**
  * Macro to register Operator.
  */
-#define REGISTER_OP(op_type, op_class, op_maker_class) \
+#define REGISTER_OP(op_type, op_class, op_maker_class, grad_op_type) \
   STATIC_ASSERT_GLOBAL_NAMESPACE( \
       __reg_op__##op_type, "REGISTER_OP must be called in global namespace"); \
   static ::paddle::framework::OpRegistrar<op_class, op_maker_class> \
-      __op_registrar_##op_type##__(#op_type); \
+      __op_registrar_##op_type##__(#op_type, #grad_op_type); \
   int TouchOpRegistrar_##op_type() { \
     __op_registrar_##op_type##__.Touch(); \
     return 0; \
   }

-/**
- * Macro to register Gradient Operator.
- */
-#define REGISTER_GRADIENT_OP(op_type, grad_op_type, grad_op_class) \
-  STATIC_ASSERT_GLOBAL_NAMESPACE( \
-      __reg_gradient_op__##op_type##_##grad_op_type, \
-      "REGISTER_GRADIENT_OP must be called in global namespace"); \
-  static ::paddle::framework::GradOpRegistrar<grad_op_class> \
-      __op_gradient_registrar_##op_type##_##grad_op_type##__(#op_type, \
-                                                             #grad_op_type); \
-  int TouchOpGradientRegistrar_##op_type() { \
-    __op_gradient_registrar_##op_type##_##grad_op_type##__.Touch(); \
-    return 0; \
-  }
+#define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \
+  REGISTER_OP(op_type, op_class, op_maker_class, )
+
+#define REGISTER_GRADIENT_OP(op_type, op_class) \
+  REGISTER_OP(op_type, op_class, ::paddle::framework::NOPMaker, )

 /**
  * Macro to register OperatorKernel.
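As an illustration of the merged registration macros (the operator and class names below are hypothetical, not taken from this diff):

    // Forward op and its gradient op, tied together at registration time.
    REGISTER_OP(my_add, MyAddOp, MyAddOpMaker, my_add_grad);
    REGISTER_GRADIENT_OP(my_add_grad, MyAddGradOp);
    // Forward-only op; the macro passes an empty grad_op_type.
    REGISTER_OP_WITHOUT_GRADIENT(my_scale, MyScaleOp, MyScaleOpMaker);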
@@ -400,10 +382,12 @@ class OpKernelRegistrar : public Registrar {
 /**
  * Macro to Forbid user register Gradient Operator.
  */
+/*
 #define NO_GRADIENT(op_type) \
   STATIC_ASSERT_GLOBAL_NAMESPACE( \
       __reg_gradient_op__##op_type##_##op_type##_grad, \
       "NO_GRADIENT must be called in global namespace")
+*/

 #define REGISTER_OP_GPU_KERNEL(op_type, ...) \
   REGISTER_OP_KERNEL(op_type, GPU, ::paddle::platform::GPUPlace, __VA_ARGS__)
@@ -423,23 +407,6 @@ class OpKernelRegistrar : public Registrar {
   static int use_op_itself_##op_type##_ __attribute__((unused)) = \
       TouchOpRegistrar_##op_type()

-// TODO(fengjiayi): Most ops' gradient op have not been compeleted. So we use
-// `NO_GRAD` to disable micro USE_OP_GRADIENT(op_type). Otherwise the code can't
-// be compiled. `NO_GRAD` should be removed after all gradient ops are
-// compeleted.
-#define NO_GRAD
-#ifndef NO_GRAD
-#define USE_OP_GRADIENT(op_type) \
-  STATIC_ASSERT_GLOBAL_NAMESPACE( \
-      __use_op_gradient_##op_type, \
-      "USE_OP_GRADIENT must be called in global namespace"); \
-  extern int TouchOpGradientRegistrar_##op_type(); \
-  static int use_op_gradient_##op_type##_ __attribute__((unused)) = \
-      TouchOpGradientRegistrar_##op_type()
-#else
-#define USE_OP_GRADIENT(op_type)
-#endif
-
 #define USE_OP_DEVICE_KERNEL(op_type, DEVICE_TYPE) \
   STATIC_ASSERT_GLOBAL_NAMESPACE( \
       __use_op_kernel_##op_type##_##DEVICE_TYPE##__, \
@@ -459,18 +426,13 @@ class OpKernelRegistrar : public Registrar {
   USE_OP_DEVICE_KERNEL(op_type, GPU)
 #endif

-#define USE_NO_GRAD_OP(op_type) \
-  USE_OP_ITSELF(op_type); \
-  USE_OP_KERNEL(op_type)
-
-#define USE_CPU_OP(op_type) \
+#define USE_CPU_ONLY_OP(op_type) \
   USE_OP_ITSELF(op_type); \
-  USE_OP_DEVICE_KERNEL(op_type, CPU); \
-  USE_OP_GRADIENT(op_type)
+  USE_OP_DEVICE_KERNEL(op_type, CPU);

 #define USE_OP(op_type) \
-  USE_NO_GRAD_OP(op_type); \
-  USE_OP_GRADIENT(op_type)
+  USE_OP_ITSELF(op_type); \
+  USE_OP_KERNEL(op_type)

 }  // namespace framework
 }  // namespace paddle
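Finally, a use-side sketch of the trimmed USE_* macros (op names hypothetical):

    USE_OP(my_add);             // op registration plus its kernels
    USE_CPU_ONLY_OP(my_scale);  // op that registers CPU kernels only
    USE_OP_ITSELF(my_plain);    // op without any kernel registration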