|
|
@ -46,8 +46,7 @@ struct variant_caster<V<Ts...>> {
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T>
|
|
|
|
template <typename T>
|
|
|
|
typename std::enable_if<
|
|
|
|
typename std::enable_if<
|
|
|
|
!std::is_same<T, boost::detail::variant::void_>::value,
|
|
|
|
!std::is_same<T, boost::detail::variant::void_>::value, bool>::type
|
|
|
|
bool>::type
|
|
|
|
|
|
|
|
try_load(handle src, bool convert) {
|
|
|
|
try_load(handle src, bool convert) {
|
|
|
|
auto caster = make_caster<T>();
|
|
|
|
auto caster = make_caster<T>();
|
|
|
|
if (!load_success_ && caster.load(src, convert)) {
|
|
|
|
if (!load_success_ && caster.load(src, convert)) {
|
|
|
@ -71,8 +70,7 @@ struct variant_caster<V<Ts...>> {
|
|
|
|
return load_success_;
|
|
|
|
return load_success_;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
static handle cast(Type const &src,
|
|
|
|
static handle cast(Type const &src, return_value_policy policy,
|
|
|
|
return_value_policy policy,
|
|
|
|
|
|
|
|
handle parent) {
|
|
|
|
handle parent) {
|
|
|
|
variant_caster_visitor visitor(policy, parent);
|
|
|
|
variant_caster_visitor visitor(policy, parent);
|
|
|
|
return boost::apply_visitor(visitor, src);
|
|
|
|
return boost::apply_visitor(visitor, src);
|
|
|
@ -101,8 +99,8 @@ inline std::vector<T> RepeatedToVector(
|
|
|
|
const google::protobuf::RepeatedField<T> &repeated_field) {
|
|
|
|
const google::protobuf::RepeatedField<T> &repeated_field) {
|
|
|
|
std::vector<T> ret;
|
|
|
|
std::vector<T> ret;
|
|
|
|
ret.reserve(repeated_field.size());
|
|
|
|
ret.reserve(repeated_field.size());
|
|
|
|
std::copy(
|
|
|
|
std::copy(repeated_field.begin(), repeated_field.end(),
|
|
|
|
repeated_field.begin(), repeated_field.end(), std::back_inserter(ret));
|
|
|
|
std::back_inserter(ret));
|
|
|
|
return ret;
|
|
|
|
return ret;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -134,7 +132,7 @@ class VarDescBind;
|
|
|
|
// read/write speed. Only when we want the protobuf message, the local changes
|
|
|
|
// read/write speed. Only when we want the protobuf message, the local changes
|
|
|
|
// will be synchronized (by `Sync` method).
|
|
|
|
// will be synchronized (by `Sync` method).
|
|
|
|
class VarDescBind {
|
|
|
|
class VarDescBind {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
explicit VarDescBind(const std::string &name) { desc_.set_name(name); }
|
|
|
|
explicit VarDescBind(const std::string &name) { desc_.set_name(name); }
|
|
|
|
|
|
|
|
|
|
|
|
VarDesc *Proto() { return &desc_; }
|
|
|
|
VarDesc *Proto() { return &desc_; }
|
|
|
@ -157,12 +155,12 @@ public:
|
|
|
|
return desc_.lod_tensor().data_type();
|
|
|
|
return desc_.lod_tensor().data_type();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
private:
|
|
|
|
VarDesc desc_;
|
|
|
|
VarDesc desc_;
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
class OpDescBind {
|
|
|
|
class OpDescBind {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
OpDesc *Proto() {
|
|
|
|
OpDesc *Proto() {
|
|
|
|
Sync();
|
|
|
|
Sync();
|
|
|
|
return &op_desc_;
|
|
|
|
return &op_desc_;
|
|
|
@ -174,8 +172,8 @@ public:
|
|
|
|
|
|
|
|
|
|
|
|
const std::vector<std::string> &Input(const std::string &name) const {
|
|
|
|
const std::vector<std::string> &Input(const std::string &name) const {
|
|
|
|
auto it = inputs_.find(name);
|
|
|
|
auto it = inputs_.find(name);
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
PADDLE_ENFORCE(it != inputs_.end(), "Input %s cannot be found in Op %s",
|
|
|
|
it != inputs_.end(), "Input %s cannot be found in Op %s", name, Type());
|
|
|
|
name, Type());
|
|
|
|
return it->second;
|
|
|
|
return it->second;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -196,10 +194,8 @@ public:
|
|
|
|
|
|
|
|
|
|
|
|
const std::vector<std::string> &Output(const std::string &name) const {
|
|
|
|
const std::vector<std::string> &Output(const std::string &name) const {
|
|
|
|
auto it = outputs_.find(name);
|
|
|
|
auto it = outputs_.find(name);
|
|
|
|
PADDLE_ENFORCE(it != outputs_.end(),
|
|
|
|
PADDLE_ENFORCE(it != outputs_.end(), "Output %s cannot be found in Op %s",
|
|
|
|
"Output %s cannot be found in Op %s",
|
|
|
|
name, Type());
|
|
|
|
name,
|
|
|
|
|
|
|
|
Type());
|
|
|
|
|
|
|
|
return it->second;
|
|
|
|
return it->second;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -258,7 +254,7 @@ public:
|
|
|
|
return boost::get<BlockDesc *>(it->second)->idx();
|
|
|
|
return boost::get<BlockDesc *>(it->second)->idx();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
private:
|
|
|
|
struct SetAttrDescVisitor : public boost::static_visitor<void> {
|
|
|
|
struct SetAttrDescVisitor : public boost::static_visitor<void> {
|
|
|
|
explicit SetAttrDescVisitor(OpDesc::Attr *attr) : attr_(attr) {}
|
|
|
|
explicit SetAttrDescVisitor(OpDesc::Attr *attr) : attr_(attr) {}
|
|
|
|
mutable OpDesc::Attr *attr_;
|
|
|
|
mutable OpDesc::Attr *attr_;
|
|
|
@ -325,7 +321,7 @@ private:
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
class BlockDescBind {
|
|
|
|
class BlockDescBind {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
BlockDescBind(ProgramDescBind *prog, BlockDesc *desc)
|
|
|
|
BlockDescBind(ProgramDescBind *prog, BlockDesc *desc)
|
|
|
|
: prog_(prog), desc_(desc), need_update_(false) {}
|
|
|
|
: prog_(prog), desc_(desc), need_update_(false) {}
|
|
|
|
|
|
|
|
|
|
|
@ -349,8 +345,8 @@ public:
|
|
|
|
VarDescBind *Var(py::bytes name_bytes) const {
|
|
|
|
VarDescBind *Var(py::bytes name_bytes) const {
|
|
|
|
std::string name = name_bytes;
|
|
|
|
std::string name = name_bytes;
|
|
|
|
auto it = vars_.find(name);
|
|
|
|
auto it = vars_.find(name);
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
PADDLE_ENFORCE(it != vars_.end(),
|
|
|
|
it != vars_.end(), "Can not find variable %s in current block.", name);
|
|
|
|
"Can not find variable %s in current block.", name);
|
|
|
|
return it->second.get();
|
|
|
|
return it->second.get();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -398,7 +394,7 @@ public:
|
|
|
|
|
|
|
|
|
|
|
|
BlockDesc *RawPtr() { return desc_; }
|
|
|
|
BlockDesc *RawPtr() { return desc_; }
|
|
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
private:
|
|
|
|
ProgramDescBind *prog_; // not_own
|
|
|
|
ProgramDescBind *prog_; // not_own
|
|
|
|
BlockDesc *desc_; // not_own
|
|
|
|
BlockDesc *desc_; // not_own
|
|
|
|
bool need_update_;
|
|
|
|
bool need_update_;
|
|
|
@ -412,7 +408,7 @@ using ProgDescMap =
|
|
|
|
static ProgDescMap *g_bind_map = nullptr;
|
|
|
|
static ProgDescMap *g_bind_map = nullptr;
|
|
|
|
|
|
|
|
|
|
|
|
class ProgramDescBind {
|
|
|
|
class ProgramDescBind {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
static ProgramDescBind &Instance(ProgramDesc *prog) {
|
|
|
|
static ProgramDescBind &Instance(ProgramDesc *prog) {
|
|
|
|
if (g_bind_map == nullptr) {
|
|
|
|
if (g_bind_map == nullptr) {
|
|
|
|
g_bind_map = new ProgDescMap();
|
|
|
|
g_bind_map = new ProgDescMap();
|
|
|
@ -449,7 +445,7 @@ public:
|
|
|
|
return prog_;
|
|
|
|
return prog_;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
private:
|
|
|
|
explicit ProgramDescBind(ProgramDesc *prog) : prog_(prog) {
|
|
|
|
explicit ProgramDescBind(ProgramDesc *prog) : prog_(prog) {
|
|
|
|
for (auto &block : *prog->mutable_blocks()) {
|
|
|
|
for (auto &block : *prog->mutable_blocks()) {
|
|
|
|
blocks_.emplace_back(new BlockDescBind(this, &block));
|
|
|
|
blocks_.emplace_back(new BlockDescBind(this, &block));
|
|
|
@ -492,8 +488,7 @@ void BindProgramDesc(py::module &m) {
|
|
|
|
return &ProgramDescBind::Instance(prog_desc);
|
|
|
|
return &ProgramDescBind::Instance(prog_desc);
|
|
|
|
},
|
|
|
|
},
|
|
|
|
py::return_value_policy::reference)
|
|
|
|
py::return_value_policy::reference)
|
|
|
|
.def("append_block",
|
|
|
|
.def("append_block", &ProgramDescBind::AppendBlock,
|
|
|
|
&ProgramDescBind::AppendBlock,
|
|
|
|
|
|
|
|
py::return_value_policy::reference)
|
|
|
|
py::return_value_policy::reference)
|
|
|
|
.def("block", &ProgramDescBind::Block, py::return_value_policy::reference)
|
|
|
|
.def("block", &ProgramDescBind::Block, py::return_value_policy::reference)
|
|
|
|
.def("__str__", &ProgramDescBind::DebugString)
|
|
|
|
.def("__str__", &ProgramDescBind::DebugString)
|
|
|
@ -504,20 +499,16 @@ void BindBlockDesc(py::module &m) {
|
|
|
|
py::class_<BlockDescBind>(m, "BlockDesc", "")
|
|
|
|
py::class_<BlockDescBind>(m, "BlockDesc", "")
|
|
|
|
.def_property_readonly("id", &BlockDescBind::ID)
|
|
|
|
.def_property_readonly("id", &BlockDescBind::ID)
|
|
|
|
.def_property_readonly("parent", &BlockDescBind::Parent)
|
|
|
|
.def_property_readonly("parent", &BlockDescBind::Parent)
|
|
|
|
.def("append_op",
|
|
|
|
.def("append_op", &BlockDescBind::AppendOp,
|
|
|
|
&BlockDescBind::AppendOp,
|
|
|
|
|
|
|
|
py::return_value_policy::reference)
|
|
|
|
py::return_value_policy::reference)
|
|
|
|
.def("prepend_op",
|
|
|
|
.def("prepend_op", &BlockDescBind::PrependOp,
|
|
|
|
&BlockDescBind::PrependOp,
|
|
|
|
py::return_value_policy::reference)
|
|
|
|
|
|
|
|
.def("new_var", &BlockDescBind::NewVar,
|
|
|
|
py::return_value_policy::reference)
|
|
|
|
py::return_value_policy::reference)
|
|
|
|
.def(
|
|
|
|
|
|
|
|
"new_var", &BlockDescBind::NewVar, py::return_value_policy::reference)
|
|
|
|
|
|
|
|
.def("var", &BlockDescBind::Var, py::return_value_policy::reference)
|
|
|
|
.def("var", &BlockDescBind::Var, py::return_value_policy::reference)
|
|
|
|
.def("all_vars",
|
|
|
|
.def("all_vars", &BlockDescBind::AllVars,
|
|
|
|
&BlockDescBind::AllVars,
|
|
|
|
|
|
|
|
py::return_value_policy::reference)
|
|
|
|
py::return_value_policy::reference)
|
|
|
|
.def("all_ops",
|
|
|
|
.def("all_ops", &BlockDescBind::AllOps,
|
|
|
|
&BlockDescBind::AllOps,
|
|
|
|
|
|
|
|
py::return_value_policy::reference);
|
|
|
|
py::return_value_policy::reference);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|