|
|
@ -85,6 +85,7 @@ namespace pybind {
|
|
|
|
|
|
|
|
|
|
|
|
using namespace paddle::framework; // NOLINT
|
|
|
|
using namespace paddle::framework; // NOLINT
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// convert between std::vector and protobuf repeated.
|
|
|
|
template <typename T>
|
|
|
|
template <typename T>
|
|
|
|
inline std::vector<T> RepeatedToVector(
|
|
|
|
inline std::vector<T> RepeatedToVector(
|
|
|
|
const google::protobuf::RepeatedField<T> &repeated_field) {
|
|
|
|
const google::protobuf::RepeatedField<T> &repeated_field) {
|
|
|
@ -104,6 +105,7 @@ inline void VectorToRepeated(const std::vector<T> &vec,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Specialize vector<bool>.
|
|
|
|
template <typename RepeatedField>
|
|
|
|
template <typename RepeatedField>
|
|
|
|
inline void VectorToRepeated(const std::vector<bool> &vec,
|
|
|
|
inline void VectorToRepeated(const std::vector<bool> &vec,
|
|
|
|
RepeatedField *repeated_field) {
|
|
|
|
RepeatedField *repeated_field) {
|
|
|
@ -118,13 +120,16 @@ class OpDescBind;
|
|
|
|
class BlockDescBind;
|
|
|
|
class BlockDescBind;
|
|
|
|
class VarDescBind;
|
|
|
|
class VarDescBind;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Each Protobuf Message, we provide a XXXBind class. In that class, we optimize
|
|
|
|
|
|
|
|
// read/write speed. Only when we want the protobuf message, the local changes
|
|
|
|
|
|
|
|
// will be synchronized (by `Sync` method).
|
|
|
|
class VarDescBind {
|
|
|
|
class VarDescBind {
|
|
|
|
public:
|
|
|
|
public:
|
|
|
|
explicit VarDescBind(const std::string &name) { desc_.set_name(name); }
|
|
|
|
explicit VarDescBind(const std::string &name) { desc_.set_name(name); }
|
|
|
|
|
|
|
|
|
|
|
|
VarDesc *Proto() { return &desc_; }
|
|
|
|
VarDesc *Proto() { return &desc_; }
|
|
|
|
|
|
|
|
|
|
|
|
py::bytes Name() { return desc_.name(); }
|
|
|
|
py::bytes Name() const { return desc_.name(); }
|
|
|
|
|
|
|
|
|
|
|
|
void SetShape(const std::vector<int64_t> &dims) {
|
|
|
|
void SetShape(const std::vector<int64_t> &dims) {
|
|
|
|
VectorToRepeated(dims, desc_.mutable_lod_tensor()->mutable_dims());
|
|
|
|
VectorToRepeated(dims, desc_.mutable_lod_tensor()->mutable_dims());
|
|
|
@ -134,11 +139,13 @@ public:
|
|
|
|
desc_.mutable_lod_tensor()->set_data_type(data_type);
|
|
|
|
desc_.mutable_lod_tensor()->set_data_type(data_type);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<int64_t> Shape() {
|
|
|
|
std::vector<int64_t> Shape() const {
|
|
|
|
return RepeatedToVector(desc_.lod_tensor().dims());
|
|
|
|
return RepeatedToVector(desc_.lod_tensor().dims());
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
framework::DataType DataType() { return desc_.lod_tensor().data_type(); }
|
|
|
|
framework::DataType DataType() const {
|
|
|
|
|
|
|
|
return desc_.lod_tensor().data_type();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
private:
|
|
|
|
VarDesc desc_;
|
|
|
|
VarDesc desc_;
|
|
|
@ -283,16 +290,16 @@ public:
|
|
|
|
|
|
|
|
|
|
|
|
void SetBlockAttr(const std::string &name, BlockDescBind &block);
|
|
|
|
void SetBlockAttr(const std::string &name, BlockDescBind &block);
|
|
|
|
|
|
|
|
|
|
|
|
int GetBlockAttr(const std::string &name) const {
|
|
|
|
Attribute GetAttr(const std::string &name) const {
|
|
|
|
auto it = attrs_.find(name);
|
|
|
|
auto it = attrs_.find(name);
|
|
|
|
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
|
|
|
|
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
|
|
|
|
return boost::get<BlockDesc *>(it->second)->idx();
|
|
|
|
return it->second;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
Attribute GetAttr(const std::string &name) const {
|
|
|
|
int GetBlockAttr(const std::string &name) const {
|
|
|
|
auto it = attrs_.find(name);
|
|
|
|
auto it = attrs_.find(name);
|
|
|
|
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
|
|
|
|
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
|
|
|
|
return it->second;
|
|
|
|
return boost::get<BlockDesc *>(it->second)->idx();
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
private:
|
|
|
@ -312,7 +319,7 @@ public:
|
|
|
|
BlockDescBind(const BlockDescBind &o) = delete;
|
|
|
|
BlockDescBind(const BlockDescBind &o) = delete;
|
|
|
|
BlockDescBind &operator=(const BlockDescBind &o) = delete;
|
|
|
|
BlockDescBind &operator=(const BlockDescBind &o) = delete;
|
|
|
|
|
|
|
|
|
|
|
|
int32_t id() const { return desc_->idx(); }
|
|
|
|
int32_t ID() const { return desc_->idx(); }
|
|
|
|
|
|
|
|
|
|
|
|
int32_t Parent() const { return desc_->parent_idx(); }
|
|
|
|
int32_t Parent() const { return desc_->parent_idx(); }
|
|
|
|
|
|
|
|
|
|
|
@ -410,7 +417,7 @@ public:
|
|
|
|
|
|
|
|
|
|
|
|
BlockDescBind *AppendBlock(const BlockDescBind &parent) {
|
|
|
|
BlockDescBind *AppendBlock(const BlockDescBind &parent) {
|
|
|
|
auto *b = prog_->add_blocks();
|
|
|
|
auto *b = prog_->add_blocks();
|
|
|
|
b->set_parent_idx(parent.id());
|
|
|
|
b->set_parent_idx(parent.ID());
|
|
|
|
b->set_idx(prog_->blocks_size() - 1);
|
|
|
|
b->set_idx(prog_->blocks_size() - 1);
|
|
|
|
blocks_.emplace_back(new BlockDescBind(this, b));
|
|
|
|
blocks_.emplace_back(new BlockDescBind(this, b));
|
|
|
|
return blocks_.back().get();
|
|
|
|
return blocks_.back().get();
|
|
|
@ -454,6 +461,7 @@ void OpDescBind::SetBlockAttr(const std::string &name, BlockDescBind &block) {
|
|
|
|
this->attrs_[name] = desc;
|
|
|
|
this->attrs_[name] = desc;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// Bind Methods
|
|
|
|
void BindProgramDesc(py::module &m) {
|
|
|
|
void BindProgramDesc(py::module &m) {
|
|
|
|
py::class_<ProgramDescBind>(m, "ProgramDesc", "")
|
|
|
|
py::class_<ProgramDescBind>(m, "ProgramDesc", "")
|
|
|
|
.def_static("instance",
|
|
|
|
.def_static("instance",
|
|
|
@ -481,7 +489,7 @@ void BindProgramDesc(py::module &m) {
|
|
|
|
|
|
|
|
|
|
|
|
void BindBlockDesc(py::module &m) {
|
|
|
|
void BindBlockDesc(py::module &m) {
|
|
|
|
py::class_<BlockDescBind>(m, "BlockDesc", "")
|
|
|
|
py::class_<BlockDescBind>(m, "BlockDesc", "")
|
|
|
|
.def_property_readonly("id", &BlockDescBind::id)
|
|
|
|
.def_property_readonly("id", &BlockDescBind::ID)
|
|
|
|
.def_property_readonly("parent", &BlockDescBind::Parent)
|
|
|
|
.def_property_readonly("parent", &BlockDescBind::Parent)
|
|
|
|
.def("append_op",
|
|
|
|
.def("append_op",
|
|
|
|
&BlockDescBind::AppendOp,
|
|
|
|
&BlockDescBind::AppendOp,
|
|
|
|