|
|
|
@ -42,15 +42,23 @@ inline void VectorToRepeated(const std::vector<T> &vec,
|
|
|
|
|
class ProgramDescBind;
|
|
|
|
|
class OpDescBind;
|
|
|
|
|
class BlockDescBind;
|
|
|
|
|
class VarDescBind;
|
|
|
|
|
|
|
|
|
|
class OpDescBind {
|
|
|
|
|
class VarDescBind {
|
|
|
|
|
public:
|
|
|
|
|
explicit OpDescBind(BlockDescBind *block) : block_(block) {}
|
|
|
|
|
explicit VarDescBind(const std::string &name) { var_desc_.set_name(name); }
|
|
|
|
|
|
|
|
|
|
VarDesc *Proto() { return &var_desc_; }
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
VarDesc var_desc_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
operator OpDesc *() { return &op_desc_; }
|
|
|
|
|
class OpDescBind {
|
|
|
|
|
public:
|
|
|
|
|
OpDesc *Proto() { return &op_desc_; }
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
BlockDescBind *block_;
|
|
|
|
|
OpDesc op_desc_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
@ -59,14 +67,28 @@ public:
|
|
|
|
|
BlockDescBind(ProgramDescBind *prog, BlockDesc *desc)
|
|
|
|
|
: prog_(prog), desc_(desc), need_update_(false) {}
|
|
|
|
|
|
|
|
|
|
BlockDescBind(const BlockDescBind &o) = delete;
|
|
|
|
|
BlockDescBind &operator=(const BlockDescBind &o) = delete;
|
|
|
|
|
|
|
|
|
|
int32_t id() const { return desc_->idx(); }
|
|
|
|
|
|
|
|
|
|
int32_t Parent() const { return desc_->parent_idx(); }
|
|
|
|
|
|
|
|
|
|
VarDescBind *NewVar(const std::string &name) {
|
|
|
|
|
need_update_ = true;
|
|
|
|
|
auto it = vars_.find(name);
|
|
|
|
|
PADDLE_ENFORCE(it == vars_.end(), "Duplicated variable %s", name);
|
|
|
|
|
auto var = new VarDescBind(name);
|
|
|
|
|
vars_[name].reset(var);
|
|
|
|
|
return var;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
BlockDescBind *ParentBlock() const;
|
|
|
|
|
|
|
|
|
|
OpDescBind *AppendOp() {
|
|
|
|
|
need_update_ = true;
|
|
|
|
|
ops_.emplace_back(this);
|
|
|
|
|
return &ops_.back();
|
|
|
|
|
ops_.emplace_back(new OpDescBind());
|
|
|
|
|
return ops_.back().get();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Sync() {
|
|
|
|
@ -75,8 +97,9 @@ public:
|
|
|
|
|
op_field.Clear();
|
|
|
|
|
op_field.Reserve(static_cast<int>(ops_.size()));
|
|
|
|
|
for (auto &op_desc : ops_) {
|
|
|
|
|
op_field.AddAllocated(op_desc);
|
|
|
|
|
op_field.AddAllocated(op_desc->Proto());
|
|
|
|
|
}
|
|
|
|
|
need_update_ = false;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -85,7 +108,8 @@ private:
|
|
|
|
|
BlockDesc *desc_; // not_own
|
|
|
|
|
bool need_update_;
|
|
|
|
|
|
|
|
|
|
std::deque<OpDescBind> ops_;
|
|
|
|
|
std::deque<std::unique_ptr<OpDescBind>> ops_;
|
|
|
|
|
std::unordered_map<std::string, std::unique_ptr<VarDescBind>> vars_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
using ProgDescMap =
|
|
|
|
@ -106,18 +130,20 @@ public:
|
|
|
|
|
}
|
|
|
|
|
return *ptr;
|
|
|
|
|
}
|
|
|
|
|
ProgramDescBind(const ProgramDescBind &o) = delete;
|
|
|
|
|
ProgramDescBind &operator=(const ProgramDescBind &o) = delete;
|
|
|
|
|
|
|
|
|
|
BlockDescBind *AppendBlock(const BlockDescBind &parent) {
|
|
|
|
|
auto *b = prog_->add_blocks();
|
|
|
|
|
b->set_parent_idx(parent.id());
|
|
|
|
|
b->set_idx(prog_->blocks_size() - 1);
|
|
|
|
|
blocks_.emplace_back(this, b);
|
|
|
|
|
return &blocks_.back();
|
|
|
|
|
blocks_.emplace_back(new BlockDescBind(this, b));
|
|
|
|
|
return blocks_.back().get();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
BlockDescBind *Root() { return &blocks_.front(); }
|
|
|
|
|
BlockDescBind *Root() { return blocks_.front().get(); }
|
|
|
|
|
|
|
|
|
|
BlockDescBind *Block(size_t idx) { return &blocks_[idx]; }
|
|
|
|
|
BlockDescBind *Block(size_t idx) { return blocks_[idx].get(); }
|
|
|
|
|
|
|
|
|
|
std::string DebugString() { return Proto()->DebugString(); }
|
|
|
|
|
|
|
|
|
@ -125,25 +151,31 @@ public:
|
|
|
|
|
|
|
|
|
|
ProgramDesc *Proto() {
|
|
|
|
|
for (auto &block : blocks_) {
|
|
|
|
|
block.Sync();
|
|
|
|
|
block->Sync();
|
|
|
|
|
}
|
|
|
|
|
return prog_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
private:
|
|
|
|
|
explicit ProgramDescBind(ProgramDesc *prog) : prog_(prog) {
|
|
|
|
|
blocks_.reserve(100);
|
|
|
|
|
for (auto &block : *prog->mutable_blocks()) {
|
|
|
|
|
blocks_.emplace_back(this, &block);
|
|
|
|
|
blocks_.emplace_back(new BlockDescBind(this, &block));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Not owned
|
|
|
|
|
ProgramDesc *prog_;
|
|
|
|
|
|
|
|
|
|
std::vector<BlockDescBind> blocks_;
|
|
|
|
|
std::vector<std::unique_ptr<BlockDescBind>> blocks_;
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
BlockDescBind *BlockDescBind::ParentBlock() const {
|
|
|
|
|
if (this->desc_->parent_idx() == -1) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
return prog_->Block(static_cast<size_t>(this->desc_->parent_idx()));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void BindProgramDesc(py::module &m) {
|
|
|
|
|
py::class_<ProgramDescBind>(m, "ProgramDesc", "")
|
|
|
|
|
.def_static("instance",
|
|
|
|
|