|
|
|
@ -1,3 +1,17 @@
|
|
|
|
|
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
|
|
|
|
|
|
|
|
|
|
Licensed under the Apache License, Version 2.0 (the "License");
|
|
|
|
|
you may not use this file except in compliance with the License.
|
|
|
|
|
You may obtain a copy of the License at
|
|
|
|
|
|
|
|
|
|
http://www.apache.org/licenses/LICENSE-2.0
|
|
|
|
|
|
|
|
|
|
Unless required by applicable law or agreed to in writing, software
|
|
|
|
|
distributed under the License is distributed on an "AS IS" BASIS,
|
|
|
|
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
|
|
|
See the License for the specific language governing permissions and
|
|
|
|
|
limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#pragma once
|
|
|
|
|
|
|
|
|
|
#include <algorithm>
|
|
|
|
@ -6,9 +20,9 @@
|
|
|
|
|
#include <unordered_map>
|
|
|
|
|
#include <unordered_set>
|
|
|
|
|
#include "paddle/framework/attr_checker.h"
|
|
|
|
|
#include "paddle/framework/grad_op_builder.h"
|
|
|
|
|
#include "paddle/framework/op_desc.pb.h"
|
|
|
|
|
#include "paddle/framework/op_proto.pb.h"
|
|
|
|
|
#include "paddle/framework/operator.h"
|
|
|
|
|
#include "paddle/framework/scope.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace framework {
|
|
|
|
@ -73,25 +87,29 @@ class OpProtoAndCheckerMaker {
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
void AddInput(const std::string& name, const std::string& comment,
|
|
|
|
|
bool multiple = false) {
|
|
|
|
|
bool multiple = false, bool ignore_gradient = false) {
|
|
|
|
|
auto input = proto_->mutable_inputs()->Add();
|
|
|
|
|
*input->mutable_name() = name;
|
|
|
|
|
*input->mutable_comment() = comment;
|
|
|
|
|
input->set_ignore_gradient(ignore_gradient);
|
|
|
|
|
input->set_multiple(multiple);
|
|
|
|
|
if (multiple) {
|
|
|
|
|
SetHasMultipleInput();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AddInputs(const std::string& name, const std::string& comment) {
|
|
|
|
|
AddInput(name, comment, true);
|
|
|
|
|
void AddInputs(const std::string& name, const std::string& comment,
|
|
|
|
|
bool ignore_gradient = false) {
|
|
|
|
|
AddInput(name, comment, true, ignore_gradient);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AddOutput(const std::string& name, const std::string& comment,
|
|
|
|
|
bool temporary = false, bool multiple = false) {
|
|
|
|
|
bool temporary = false, bool multiple = false,
|
|
|
|
|
bool ignore_gradient = false) {
|
|
|
|
|
auto output = proto_->mutable_outputs()->Add();
|
|
|
|
|
*output->mutable_name() = name;
|
|
|
|
|
*output->mutable_comment() = comment;
|
|
|
|
|
output->set_ignore_gradient(ignore_gradient);
|
|
|
|
|
output->set_multiple(multiple);
|
|
|
|
|
if (multiple) {
|
|
|
|
|
SetHasMultipleOutput();
|
|
|
|
@ -103,8 +121,8 @@ class OpProtoAndCheckerMaker {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void AddOutputs(const std::string& name, const std::string& comment,
|
|
|
|
|
bool temporary = false) {
|
|
|
|
|
AddOutput(name, comment, temporary, true);
|
|
|
|
|
bool temporary = false, bool ignore_gradient = false) {
|
|
|
|
|
AddOutput(name, comment, temporary, true, ignore_gradient);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
@ -204,9 +222,9 @@ class OpRegistry {
|
|
|
|
|
public:
|
|
|
|
|
template <typename OpType, typename ProtoMakerType>
|
|
|
|
|
static void RegisterOp(const std::string& op_type) {
|
|
|
|
|
creators()[op_type] = [] { return new OpType; };
|
|
|
|
|
OpProto& op_proto = protos()[op_type];
|
|
|
|
|
op_creators()[op_type] = [] { return new OpType; };
|
|
|
|
|
OpAttrChecker& op_checker = op_checkers()[op_type];
|
|
|
|
|
OpProto& op_proto = protos()[op_type];
|
|
|
|
|
auto maker = ProtoMakerType(&op_proto, &op_checker);
|
|
|
|
|
maker.Validate();
|
|
|
|
|
*op_proto.mutable_type() = op_type;
|
|
|
|
@ -227,18 +245,26 @@ class OpRegistry {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static OperatorPtr CreateOp(const std::string& type,
|
|
|
|
|
const VarNameList& inputs,
|
|
|
|
|
const VarNameList& outputs,
|
|
|
|
|
const AttributeMap& attrs) {
|
|
|
|
|
auto op_create_it = creators().find(type);
|
|
|
|
|
PADDLE_ENFORCE(op_create_it != creators().end(),
|
|
|
|
|
"Operator %s cannot be found", type);
|
|
|
|
|
template <typename GradOpType>
|
|
|
|
|
static void RegisterGradOp(const std::string& op_type,
|
|
|
|
|
const std::string& grad_op_type) {
|
|
|
|
|
op_creators()[grad_op_type] = [] { return new GradOpType; };
|
|
|
|
|
grad_ops()[op_type] = grad_op_type;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
|
|
|
|
|
const VarNameList& inputs,
|
|
|
|
|
const VarNameList& outputs,
|
|
|
|
|
const AttributeMap& attrs) {
|
|
|
|
|
auto op_create_it = op_creators().find(type);
|
|
|
|
|
PADDLE_ENFORCE(op_create_it != op_creators().end(),
|
|
|
|
|
"Operator %s cannot be found.", type);
|
|
|
|
|
|
|
|
|
|
auto op = op_create_it->second();
|
|
|
|
|
op->type_ = type;
|
|
|
|
|
op->inputs_ = inputs;
|
|
|
|
|
op->outputs_ = outputs;
|
|
|
|
|
|
|
|
|
|
op->attrs_ = attrs;
|
|
|
|
|
op_checkers().at(type).Check(op->attrs_);
|
|
|
|
|
|
|
|
|
@ -252,10 +278,10 @@ class OpRegistry {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
op->Init();
|
|
|
|
|
return OperatorPtr(op);
|
|
|
|
|
return std::shared_ptr<OperatorBase>(op);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static OperatorPtr CreateOp(const OpDesc& op_desc) {
|
|
|
|
|
static std::shared_ptr<OperatorBase> CreateOp(const OpDesc& op_desc) {
|
|
|
|
|
std::vector<std::string> inputs;
|
|
|
|
|
inputs.reserve((size_t)op_desc.inputs_size());
|
|
|
|
|
std::copy(op_desc.inputs().begin(), op_desc.inputs().end(),
|
|
|
|
@ -274,18 +300,41 @@ class OpRegistry {
|
|
|
|
|
return CreateOp(op_desc.type(), inputs, outputs, attrs);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::shared_ptr<OperatorBase> CreateGradOp(
|
|
|
|
|
std::shared_ptr<OperatorBase> op) {
|
|
|
|
|
GradOpBuilder builder(op.get());
|
|
|
|
|
std::shared_ptr<OperatorBase> grad_op(builder.Build());
|
|
|
|
|
grad_op->Init();
|
|
|
|
|
return grad_op;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::unordered_map<std::string, OpProto>& protos() {
|
|
|
|
|
static std::unordered_map<std::string, OpProto> protos_;
|
|
|
|
|
return protos_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
static std::unordered_map<std::string, std::string>& grad_ops() {
|
|
|
|
|
static std::unordered_map<std::string, std::string> grad_ops_;
|
|
|
|
|
return grad_ops_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>>&
|
|
|
|
|
VarIndexMaps() {
|
|
|
|
|
static std::unordered_map<std::string, std::shared_ptr<VarIndexMap>> maps_;
|
|
|
|
|
return maps_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::unordered_map<std::string, OpCreator>& op_creators() {
|
|
|
|
|
static std::unordered_map<std::string, OpCreator> op_creators_;
|
|
|
|
|
return op_creators_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
static std::unordered_map<std::string, OpAttrChecker>& op_checkers() {
|
|
|
|
|
static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
|
|
|
|
|
return op_checkers_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
static void GenerateTempVariableName(OperatorBase* op) {
|
|
|
|
|
static std::atomic<size_t> gUniqId(0UL);
|
|
|
|
|
for (auto& outname : op->outputs_) {
|
|
|
|
@ -296,16 +345,6 @@ class OpRegistry {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::unordered_map<std::string, OpCreator>& creators() {
|
|
|
|
|
static std::unordered_map<std::string, OpCreator> creators_;
|
|
|
|
|
return creators_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::unordered_map<std::string, OpAttrChecker>& op_checkers() {
|
|
|
|
|
static std::unordered_map<std::string, OpAttrChecker> op_checkers_;
|
|
|
|
|
return op_checkers_;
|
|
|
|
|
};
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename OpType, typename ProtoMakerType>
|
|
|
|
@ -316,6 +355,14 @@ class OpRegisterHelper {
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
template <typename GradOpType>
|
|
|
|
|
class GradOpRegisterHelper {
|
|
|
|
|
public:
|
|
|
|
|
GradOpRegisterHelper(const char* op_type, const char* grad_op_type) {
|
|
|
|
|
OpRegistry::RegisterGradOp<GradOpType>(op_type, grad_op_type);
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* check if MACRO is used in GLOBAL NAMESPACE.
|
|
|
|
|
*/
|
|
|
|
@ -335,6 +382,20 @@ class OpRegisterHelper {
|
|
|
|
|
__op_register_##__op_type##__(#__op_type); \
|
|
|
|
|
int __op_register_##__op_type##_handle__() { return 0; }
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Macro to Register Gradient Operator.
|
|
|
|
|
*/
|
|
|
|
|
#define REGISTER_GRADIENT_OP(__op_type, __grad_op_type, __grad_op_class) \
|
|
|
|
|
STATIC_ASSERT_GLOBAL_NAMESPACE( \
|
|
|
|
|
__reg_gradient_op__##__op_type##__grad_op_type, \
|
|
|
|
|
"REGISTER_GRADIENT_OP must be in global namespace"); \
|
|
|
|
|
static ::paddle::framework::GradOpRegisterHelper<__grad_op_class> \
|
|
|
|
|
__op_gradient_register_##__op_type##__grad_op_type##__(#__op_type, \
|
|
|
|
|
#__grad_op_type); \
|
|
|
|
|
int __op_gradient_register_##__op_type##__grad_op_type##_handle__() { \
|
|
|
|
|
return 0; \
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/**
|
|
|
|
|
* Macro to Register OperatorKernel.
|
|
|
|
|
*/
|
|
|
|
|