|
|
|
@ -99,7 +99,7 @@ template <typename T, typename RepeatedField>
|
|
|
|
|
inline void VectorToRepeated(const std::vector<T> &vec,
|
|
|
|
|
RepeatedField *repeated_field) {
|
|
|
|
|
repeated_field->Reserve(vec.size());
|
|
|
|
|
for (auto &elem : vec) {
|
|
|
|
|
for (const auto &elem : vec) {
|
|
|
|
|
*repeated_field->Add() = elem;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
@ -124,18 +124,23 @@ public:
|
|
|
|
|
|
|
|
|
|
VarDesc *Proto() { return &desc_; }
|
|
|
|
|
|
|
|
|
|
py::bytes Name() { return desc_.name(); }
|
|
|
|
|
|
|
|
|
|
void SetShape(const std::vector<int64_t> &dims) {
|
|
|
|
|
VectorToRepeated(dims, desc_.mutable_lod_tensor()->mutable_dims());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void SetDataType(int type_id) {
|
|
|
|
|
desc_.mutable_lod_tensor()->set_data_type(static_cast<DataType>(type_id));
|
|
|
|
|
desc_.mutable_lod_tensor()->set_data_type(
|
|
|
|
|
static_cast<enum DataType>(type_id));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<int64_t> Shape() {
|
|
|
|
|
return RepeatedToVector(desc_.lod_tensor().dims());
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int DataType() { return desc_.lod_tensor().data_type(); }
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
VarDesc desc_;
|
|
|
|
|
};
|
|
|
|
@ -322,6 +327,22 @@ public:
|
|
|
|
|
return var;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
VarDescBind *Var(py::bytes name_bytes) const {
|
|
|
|
|
std::string name = name_bytes;
|
|
|
|
|
auto it = vars_.find(name);
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
it != vars_.end(), "Can not find variable %s in current block.", name);
|
|
|
|
|
return it->second.get();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<VarDescBind *> AllVars() const {
|
|
|
|
|
std::vector<VarDescBind *> res;
|
|
|
|
|
for (const auto &p : vars_) {
|
|
|
|
|
res.push_back(p.second.get());
|
|
|
|
|
}
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
BlockDescBind *ParentBlock() const;
|
|
|
|
|
|
|
|
|
|
OpDescBind *AppendOp() {
|
|
|
|
@ -336,6 +357,14 @@ public:
|
|
|
|
|
return ops_.front().get();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<OpDescBind *> AllOps() const {
|
|
|
|
|
std::vector<OpDescBind *> res;
|
|
|
|
|
for (const auto &op : ops_) {
|
|
|
|
|
res.push_back(op.get());
|
|
|
|
|
}
|
|
|
|
|
return res;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Sync() {
|
|
|
|
|
if (need_update_) {
|
|
|
|
|
auto &op_field = *this->desc_->mutable_ops();
|
|
|
|
@ -461,16 +490,26 @@ void BindBlockDesc(py::module &m) {
|
|
|
|
|
.def("prepend_op",
|
|
|
|
|
&BlockDescBind::PrependOp,
|
|
|
|
|
py::return_value_policy::reference)
|
|
|
|
|
.def("new_var",
|
|
|
|
|
&BlockDescBind::NewVar,
|
|
|
|
|
.def(
|
|
|
|
|
"new_var", &BlockDescBind::NewVar, py::return_value_policy::reference)
|
|
|
|
|
.def("var", &BlockDescBind::Var, py::return_value_policy::reference)
|
|
|
|
|
.def("all_vars",
|
|
|
|
|
&BlockDescBind::AllVars,
|
|
|
|
|
py::return_value_policy::reference)
|
|
|
|
|
.def("all_ops",
|
|
|
|
|
&BlockDescBind::AllOps,
|
|
|
|
|
py::return_value_policy::reference);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void BindVarDsec(py::module &m) {
|
|
|
|
|
py::class_<VarDescBind>(m, "VarDesc", "")
|
|
|
|
|
.def("name", &VarDescBind::Name, py::return_value_policy::reference)
|
|
|
|
|
.def("set_shape", &VarDescBind::SetShape)
|
|
|
|
|
.def("set_data_type", &VarDescBind::SetDataType)
|
|
|
|
|
.def("shape", &VarDescBind::Shape);
|
|
|
|
|
.def("shape", &VarDescBind::Shape, py::return_value_policy::reference)
|
|
|
|
|
.def("data_type",
|
|
|
|
|
&VarDescBind::DataType,
|
|
|
|
|
py::return_value_policy::reference);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void BindOpDesc(py::module &m) {
|
|
|
|
|