|
|
|
@ -14,6 +14,7 @@ limitations under the License. */
|
|
|
|
|
|
|
|
|
|
#include "paddle/pybind/protobuf.h"
|
|
|
|
|
#include <deque>
|
|
|
|
|
#include "paddle/framework/attribute.h"
|
|
|
|
|
|
|
|
|
|
namespace paddle {
|
|
|
|
|
namespace pybind {
|
|
|
|
@ -56,10 +57,90 @@ private:
|
|
|
|
|
|
|
|
|
|
class OpDescBind {
|
|
|
|
|
public:
|
|
|
|
|
OpDesc *Proto() { return &op_desc_; }
|
|
|
|
|
OpDesc *Proto() {
|
|
|
|
|
Sync();
|
|
|
|
|
return &op_desc_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string Type() const { return op_desc_.type(); }
|
|
|
|
|
|
|
|
|
|
void SetType(const std::string &type) { op_desc_.set_type(type); }
|
|
|
|
|
|
|
|
|
|
const std::vector<std::string> &Input(const std::string &name) const {
|
|
|
|
|
auto it = inputs_.find(name);
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
it != inputs_.end(), "Input %s cannot be found in Op %s", name, Type());
|
|
|
|
|
return it->second;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<std::string> InputNames() const {
|
|
|
|
|
std::vector<std::string> retv;
|
|
|
|
|
retv.reserve(this->inputs_.size());
|
|
|
|
|
for (auto &ipt : this->inputs_) {
|
|
|
|
|
retv.push_back(ipt.first);
|
|
|
|
|
}
|
|
|
|
|
return retv;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetInput(const std::string ¶m_name,
|
|
|
|
|
const std::vector<std::string> &args) {
|
|
|
|
|
need_update_ = true;
|
|
|
|
|
inputs_[param_name] = args;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const std::vector<std::string> &Output(const std::string &name) const {
|
|
|
|
|
auto it = outputs_.find(name);
|
|
|
|
|
PADDLE_ENFORCE(it != outputs_.end(),
|
|
|
|
|
"Output %s cannot be found in Op %s",
|
|
|
|
|
name,
|
|
|
|
|
Type());
|
|
|
|
|
return it->second;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<std::string> OutputNames() const {
|
|
|
|
|
std::vector<std::string> retv;
|
|
|
|
|
retv.reserve(this->outputs_.size());
|
|
|
|
|
for (auto &ipt : this->outputs_) {
|
|
|
|
|
retv.push_back(ipt.first);
|
|
|
|
|
}
|
|
|
|
|
return retv;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetOutput(const std::string ¶m_name,
|
|
|
|
|
const std::vector<std::string> &args) {
|
|
|
|
|
need_update_ = true;
|
|
|
|
|
this->outputs_[param_name] = args;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::string DebugString() { return this->Proto()->DebugString(); }
|
|
|
|
|
|
|
|
|
|
void Sync() {
|
|
|
|
|
if (need_update_) {
|
|
|
|
|
this->op_desc_.mutable_inputs()->Clear();
|
|
|
|
|
for (auto &ipt : inputs_) {
|
|
|
|
|
auto *input = op_desc_.add_inputs();
|
|
|
|
|
input->set_parameter(ipt.first);
|
|
|
|
|
VectorToRepeated(ipt.second, input->mutable_arguments());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
this->op_desc_.mutable_outputs()->Clear();
|
|
|
|
|
for (auto &opt : outputs_) {
|
|
|
|
|
auto *output = op_desc_.add_outputs();
|
|
|
|
|
output->set_parameter(opt.first);
|
|
|
|
|
VectorToRepeated(opt.second, output->mutable_arguments());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
need_update_ = false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
OpDesc op_desc_;
|
|
|
|
|
std::unordered_map<std::string, std::vector<std::string>> inputs_;
|
|
|
|
|
std::unordered_map<std::string, std::vector<std::string>> outputs_;
|
|
|
|
|
std::unordered_map<std::string, Attribute> attrs_;
|
|
|
|
|
|
|
|
|
|
bool need_update_{false};
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
class BlockDescBind {
|
|
|
|
@ -141,8 +222,6 @@ public:
|
|
|
|
|
return blocks_.back().get();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
BlockDescBind *Root() { return blocks_.front().get(); }
|
|
|
|
|
|
|
|
|
|
BlockDescBind *Block(size_t idx) { return blocks_[idx].get(); }
|
|
|
|
|
|
|
|
|
|
std::string DebugString() { return Proto()->DebugString(); }
|
|
|
|
@ -196,9 +275,6 @@ void BindProgramDesc(py::module &m) {
|
|
|
|
|
.def("append_block",
|
|
|
|
|
&ProgramDescBind::AppendBlock,
|
|
|
|
|
py::return_value_policy::reference)
|
|
|
|
|
.def("root_block",
|
|
|
|
|
&ProgramDescBind::Root,
|
|
|
|
|
py::return_value_policy::reference)
|
|
|
|
|
.def("block", &ProgramDescBind::Block, py::return_value_policy::reference)
|
|
|
|
|
.def("__str__", &ProgramDescBind::DebugString)
|
|
|
|
|
.def("num_blocks", &ProgramDescBind::Size);
|
|
|
|
@ -241,52 +317,17 @@ void BindVarDsec(py::module &m) {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void BindOpDesc(py::module &m) {
|
|
|
|
|
// auto op_desc_set_var = [](OpDesc::Var *var,
|
|
|
|
|
// const std::string ¶meter,
|
|
|
|
|
// const std::vector<std::string> &arguments) {
|
|
|
|
|
// var->set_parameter(parameter);
|
|
|
|
|
// VectorToRepeated(arguments, var->mutable_arguments());
|
|
|
|
|
// };
|
|
|
|
|
//
|
|
|
|
|
// auto op_desc_set_attr = [](OpDesc &desc, const std::string &name) {
|
|
|
|
|
// auto attr = desc.add_attrs();
|
|
|
|
|
// attr->set_name(name);
|
|
|
|
|
// return attr;
|
|
|
|
|
// };
|
|
|
|
|
py::class_<OpDescBind>(m, "OpDesc", "");
|
|
|
|
|
|
|
|
|
|
// .def("type", [](OpDesc &op) { return op.type(); })
|
|
|
|
|
// .def("set_input",
|
|
|
|
|
// [op_desc_set_var](OpDesc &self,
|
|
|
|
|
// const std::string ¶meter,
|
|
|
|
|
// const std::vector<std::string> &arguments) {
|
|
|
|
|
// auto ipt = self.add_inputs();
|
|
|
|
|
// op_desc_set_var(ipt, parameter, arguments);
|
|
|
|
|
// })
|
|
|
|
|
// .def("input_names",
|
|
|
|
|
// [](OpDesc &self) {
|
|
|
|
|
// std::vector<std::string> ret_val;
|
|
|
|
|
// ret_val.reserve(static_cast<size_t>(self.inputs().size()));
|
|
|
|
|
// std::transform(
|
|
|
|
|
// self.inputs().begin(),
|
|
|
|
|
// self.inputs().end(),
|
|
|
|
|
// std::back_inserter(ret_val),
|
|
|
|
|
// [](const OpDesc::Var &var) { return var.parameter(); });
|
|
|
|
|
// return ret_val;
|
|
|
|
|
// })
|
|
|
|
|
// .def("__str__", [](OpDesc &self) { return self.DebugString(); })
|
|
|
|
|
// .def("set_output",
|
|
|
|
|
// [op_desc_set_var](OpDesc &self,
|
|
|
|
|
// const std::string ¶meter,
|
|
|
|
|
// const std::vector<std::string> &arguments) {
|
|
|
|
|
// auto opt = self.add_outputs();
|
|
|
|
|
// op_desc_set_var(opt, parameter, arguments);
|
|
|
|
|
// })
|
|
|
|
|
// .def("set_attr",
|
|
|
|
|
// [op_desc_set_attr](OpDesc &self, const std::string &name, int i)
|
|
|
|
|
// {
|
|
|
|
|
// op_desc_set_attr(self, name)->set_i(i);
|
|
|
|
|
// });
|
|
|
|
|
py::class_<OpDescBind>(m, "OpDesc", "")
|
|
|
|
|
.def("type", &OpDescBind::Type)
|
|
|
|
|
.def("set_type", &OpDescBind::SetType)
|
|
|
|
|
.def("input", &OpDescBind::Input)
|
|
|
|
|
.def("input_names", &OpDescBind::InputNames)
|
|
|
|
|
.def("set_input", &OpDescBind::SetInput)
|
|
|
|
|
.def("output", &OpDescBind::Output)
|
|
|
|
|
.def("output_names", &OpDescBind::OutputNames)
|
|
|
|
|
.def("set_output", &OpDescBind::SetOutput)
|
|
|
|
|
.def("__str__", &OpDescBind::DebugString)
|
|
|
|
|
.def("__repr__", &OpDescBind::DebugString);
|
|
|
|
|
}
|
|
|
|
|
} // namespace pybind
|
|
|
|
|
} // namespace paddle
|
|
|
|
|