|
|
|
@ -20,8 +20,8 @@ limitations under the License. */
|
|
|
|
|
#include <unordered_map>
|
|
|
|
|
#include <unordered_set>
|
|
|
|
|
#include "paddle/framework/attribute.h"
|
|
|
|
|
#include "paddle/framework/framework.pb.h"
|
|
|
|
|
#include "paddle/framework/grad_op_builder.h"
|
|
|
|
|
#include "paddle/framework/op_desc.pb.h"
|
|
|
|
|
#include "paddle/framework/scope.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
@ -44,25 +44,20 @@ class OpProtoAndCheckerMaker {
|
|
|
|
|
|
|
|
|
|
protected:
|
|
|
|
|
struct VariableBuilder {
|
|
|
|
|
VarProto* var_;
|
|
|
|
|
std::function<void()> on_multiple_;
|
|
|
|
|
std::function<void()> on_temporary_;
|
|
|
|
|
OpProto::Var* var_;
|
|
|
|
|
|
|
|
|
|
VariableBuilder& SetMultiple() {
|
|
|
|
|
var_->set_multiple(true);
|
|
|
|
|
on_multiple_();
|
|
|
|
|
var_->set_duplicable(true);
|
|
|
|
|
return *this;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
VariableBuilder& SetTemporary() {
|
|
|
|
|
PADDLE_ENFORCE(bool(on_temporary_), "Cannot set temporary");
|
|
|
|
|
var_->set_temporary(true);
|
|
|
|
|
on_temporary_();
|
|
|
|
|
var_->set_intermediate(true);
|
|
|
|
|
return *this;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
VariableBuilder& IgnoreGradient() {
|
|
|
|
|
var_->set_ignore_gradient(true);
|
|
|
|
|
var_->set_no_gradient(true);
|
|
|
|
|
return *this;
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
@ -72,8 +67,7 @@ class OpProtoAndCheckerMaker {
|
|
|
|
|
auto input = proto_->mutable_inputs()->Add();
|
|
|
|
|
*input->mutable_name() = name;
|
|
|
|
|
*input->mutable_comment() = comment;
|
|
|
|
|
return VariableBuilder{input, [=] { this->SetHasMultipleInput(); },
|
|
|
|
|
nullptr};
|
|
|
|
|
return VariableBuilder{input};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
VariableBuilder AddOutput(const std::string& name,
|
|
|
|
@ -81,8 +75,7 @@ class OpProtoAndCheckerMaker {
|
|
|
|
|
auto output = proto_->mutable_outputs()->Add();
|
|
|
|
|
*output->mutable_name() = name;
|
|
|
|
|
*output->mutable_comment() = comment;
|
|
|
|
|
return VariableBuilder{output, [=] { this->SetHasMultipleOutput(); },
|
|
|
|
|
[=] { this->SetHasTemporaryOutput(); }};
|
|
|
|
|
return VariableBuilder{output};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
@ -102,53 +95,6 @@ class OpProtoAndCheckerMaker {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
void SetHasMultiple(const std::string& in_out, bool* flag) {
|
|
|
|
|
if (!*flag) {
|
|
|
|
|
AddAttr<std::vector<int>>(in_out + "_format",
|
|
|
|
|
"The multiple index of " + in_out +
|
|
|
|
|
"\n"
|
|
|
|
|
R"DOC(
|
|
|
|
|
This attribute is used by Paddle core framework. Paddle's Op support each input
|
|
|
|
|
or output could be a list of variable. This attribute is used to show how that
|
|
|
|
|
list organized.
|
|
|
|
|
|
|
|
|
|
e.g.
|
|
|
|
|
input = ["a", "b", "c", "d", "e", "f"]
|
|
|
|
|
input_format = [0, 4, 5, 6]
|
|
|
|
|
|
|
|
|
|
means
|
|
|
|
|
The number of all input variables this op is six, and they are segmented into
|
|
|
|
|
three inputs.
|
|
|
|
|
|
|
|
|
|
The first input is input[0:4], second is input[4:5], third is input[5:6].
|
|
|
|
|
)DOC",
|
|
|
|
|
/*generated*/ true);
|
|
|
|
|
*flag = true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetHasMultipleInput() { SetHasMultiple("input", &has_multiple_input_); }
|
|
|
|
|
void SetHasMultipleOutput() {
|
|
|
|
|
SetHasMultiple("output", &has_multiple_output_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetHasTemporaryOutput() {
|
|
|
|
|
if (!has_temporary_output_) {
|
|
|
|
|
AddAttr<std::vector<int>>("temporary_index",
|
|
|
|
|
R"DOC(The temporary index of output.
|
|
|
|
|
|
|
|
|
|
Not all output of Paddle Op is used by user. For faster computation, each op
|
|
|
|
|
could output some its internal state to other op, other op could take that
|
|
|
|
|
output to make compute faster.
|
|
|
|
|
|
|
|
|
|
Add a mark to which output is temporary is helpful for future optimization.
|
|
|
|
|
)DOC",
|
|
|
|
|
/*generated*/ true)
|
|
|
|
|
.SetDefault(std::vector<int>());
|
|
|
|
|
has_temporary_output_ = true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void CheckNoDuplicatedInOutAttrs() {
|
|
|
|
|
std::unordered_set<std::string> names;
|
|
|
|
|
auto checker = [&](const std::string& name) {
|
|
|
|
@ -169,15 +115,12 @@ Add a mark to which output is temporary is helpful for future optimization.
|
|
|
|
|
OpProto* proto_;
|
|
|
|
|
OpAttrChecker* op_checker_;
|
|
|
|
|
bool validated_{false};
|
|
|
|
|
bool has_multiple_input_{false};
|
|
|
|
|
bool has_multiple_output_{false};
|
|
|
|
|
bool has_temporary_output_{false};
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class OpRegistry {
|
|
|
|
|
using OpCreator = std::function<OperatorBase*()>;
|
|
|
|
|
using VarIndexMap = std::unordered_map<std::string, int>;
|
|
|
|
|
using VarNameList = std::vector<std::string>;
|
|
|
|
|
using VarNameMap = std::unordered_map<std::string, std::vector<std::string>>;
|
|
|
|
|
|
|
|
|
|
public:
|
|
|
|
|
template <typename OpType, typename ProtoMakerType>
|
|
|
|
@ -213,8 +156,8 @@ class OpRegistry {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::shared_ptr<OperatorBase> CreateOp(const std::string& type,
|
|
|
|
|
const VarNameList& inputs,
|
|
|
|
|
const VarNameList& outputs,
|
|
|
|
|
const VarNameMap& inputs,
|
|
|
|
|
const VarNameMap& outputs,
|
|
|
|
|
const AttributeMap& attrs) {
|
|
|
|
|
auto op_create_it = op_creators().find(type);
|
|
|
|
|
PADDLE_ENFORCE(op_create_it != op_creators().end(),
|
|
|
|
@ -230,27 +173,28 @@ class OpRegistry {
|
|
|
|
|
|
|
|
|
|
GenerateTempVariableName(op);
|
|
|
|
|
|
|
|
|
|
{
|
|
|
|
|
auto var_index_it = VarIndexMaps().find(type);
|
|
|
|
|
if (var_index_it != VarIndexMaps().end()) {
|
|
|
|
|
op->in_out_idxs_ = var_index_it->second;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
op->Init();
|
|
|
|
|
return std::shared_ptr<OperatorBase>(op);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static std::shared_ptr<OperatorBase> CreateOp(const OpDesc& op_desc) {
|
|
|
|
|
std::vector<std::string> inputs;
|
|
|
|
|
inputs.reserve((size_t)op_desc.inputs_size());
|
|
|
|
|
std::copy(op_desc.inputs().begin(), op_desc.inputs().end(),
|
|
|
|
|
std::back_inserter(inputs));
|
|
|
|
|
VarNameMap inputs;
|
|
|
|
|
for (auto& input : op_desc.inputs()) {
|
|
|
|
|
auto& var_names = inputs[input.op_proto_name()];
|
|
|
|
|
auto& var_names_in_proto = input.var_names();
|
|
|
|
|
var_names.reserve(static_cast<size_t>(var_names_in_proto.size()));
|
|
|
|
|
std::copy(var_names_in_proto.begin(), var_names_in_proto.end(),
|
|
|
|
|
std::back_inserter(var_names));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<std::string> outputs;
|
|
|
|
|
outputs.reserve((size_t)op_desc.outputs_size());
|
|
|
|
|
std::copy(op_desc.outputs().begin(), op_desc.outputs().end(),
|
|
|
|
|
std::back_inserter(outputs));
|
|
|
|
|
VarNameMap outputs;
|
|
|
|
|
for (auto& output : op_desc.outputs()) {
|
|
|
|
|
auto& var_names = outputs[output.op_proto_name()];
|
|
|
|
|
auto& var_names_in_proto = output.var_names();
|
|
|
|
|
var_names.reserve(static_cast<size_t>(var_names_in_proto.size()));
|
|
|
|
|
std::copy(var_names_in_proto.begin(), var_names_in_proto.end(),
|
|
|
|
|
std::back_inserter(var_names));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
AttributeMap attrs;
|
|
|
|
|
for (auto& attr : op_desc.attrs()) {
|
|
|
|
@ -303,11 +247,13 @@ class OpRegistry {
|
|
|
|
|
|
|
|
|
|
static void GenerateTempVariableName(OperatorBase* op) {
|
|
|
|
|
static std::atomic<size_t> gUniqId(0UL);
|
|
|
|
|
for (auto& outname : op->outputs_) {
|
|
|
|
|
if (outname == kTempVarName) {
|
|
|
|
|
outname += op->type_;
|
|
|
|
|
outname += "@";
|
|
|
|
|
outname += std::to_string(gUniqId.fetch_add(1));
|
|
|
|
|
for (auto& output : op->outputs_) {
|
|
|
|
|
for (auto& output_name : output.second) {
|
|
|
|
|
if (output_name == kTempVarName) {
|
|
|
|
|
output_name += op->type_;
|
|
|
|
|
output_name += "@";
|
|
|
|
|
output_name += std::to_string(gUniqId.fetch_add(1));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|