Rename XXDescBind --> XXDesc (#6797)

* Rename XXDescBind --> XXDesc

* Fix Compile
del_some_in_makelist
Yu Yang 7 years ago committed by GitHub
parent 0295b00066
commit 091897321f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -42,7 +42,7 @@ static std::unordered_set<std::string>& CtrlFlowOps() {
static inline std::unique_ptr<OperatorBase> CreateGradOp( static inline std::unique_ptr<OperatorBase> CreateGradOp(
const OperatorBase& op, const std::unordered_set<std::string>& no_grad_set, const OperatorBase& op, const std::unordered_set<std::string>& no_grad_set,
std::unordered_map<std::string, std::string>* grad_to_var) { std::unordered_map<std::string, std::string>* grad_to_var) {
OpDescBind op_desc; OpDesc op_desc;
op_desc.SetInputMap(op.Inputs()); op_desc.SetInputMap(op.Inputs());
op_desc.SetOutputMap(op.Outputs()); op_desc.SetOutputMap(op.Outputs());
op_desc.SetType(op.Type()); op_desc.SetType(op.Type());
@ -53,7 +53,7 @@ static inline std::unique_ptr<OperatorBase> CreateGradOp(
grad_ops.reserve(grad_descs.size()); grad_ops.reserve(grad_descs.size());
std::transform(grad_descs.begin(), grad_descs.end(), std::transform(grad_descs.begin(), grad_descs.end(),
std::back_inserter(grad_ops), std::back_inserter(grad_ops),
[](const std::unique_ptr<OpDescBind>& grad_desc) { [](const std::unique_ptr<OpDesc>& grad_desc) {
return OpRegistry::CreateOp(*grad_desc); return OpRegistry::CreateOp(*grad_desc);
}); });
PADDLE_ENFORCE(!grad_ops.empty()); PADDLE_ENFORCE(!grad_ops.empty());
@ -296,7 +296,7 @@ static std::string FwdName(const std::string& grad_name) {
static void CreateGradVarInBlock( static void CreateGradVarInBlock(
size_t grad_op_start_index, size_t grad_op_start_index,
const std::unordered_map<std::string, std::string>& param_name_map, const std::unordered_map<std::string, std::string>& param_name_map,
BlockDescBind* block_desc, BlockDesc* block_desc,
std::unordered_map<std::string, GradVarInfo>* grad_var_record) { std::unordered_map<std::string, GradVarInfo>* grad_var_record) {
auto ops = block_desc->AllOps(); auto ops = block_desc->AllOps();
for (size_t op_index = grad_op_start_index; op_index < ops.size(); for (size_t op_index = grad_op_start_index; op_index < ops.size();
@ -350,12 +350,11 @@ static void CreateGradVarInBlock(
} }
} }
std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad( std::vector<std::unique_ptr<OpDesc>> MakeOpGrad(
const OpDescBind* op_desc, std::unordered_set<std::string>* no_grad_vars, const OpDesc* op_desc, std::unordered_set<std::string>* no_grad_vars,
std::unordered_map<std::string, std::string>* grad_to_var, std::unordered_map<std::string, std::string>* grad_to_var,
const std::vector<BlockDescBind*>& grad_block = const std::vector<BlockDesc*>& grad_block = std::vector<BlockDesc*>()) {
std::vector<BlockDescBind*>()) { std::vector<std::unique_ptr<OpDesc>> grad_op_descs;
std::vector<std::unique_ptr<OpDescBind>> grad_op_descs;
// All input gradients of forwarding operator do not need to calculate. // All input gradients of forwarding operator do not need to calculate.
const std::vector<std::string>& inputs = op_desc->InputArgumentNames(); const std::vector<std::string>& inputs = op_desc->InputArgumentNames();
if (AllGradInSet(inputs, *no_grad_vars)) { if (AllGradInSet(inputs, *no_grad_vars)) {
@ -386,7 +385,7 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
.Get(op_desc->Type()) .Get(op_desc->Type())
.GradOpMaker()(*op_desc, *no_grad_vars, grad_to_var, grad_block); .GradOpMaker()(*op_desc, *no_grad_vars, grad_to_var, grad_block);
std::list<std::unique_ptr<OpDescBind>> pending_fill_zeros_ops; std::list<std::unique_ptr<OpDesc>> pending_fill_zeros_ops;
for (auto& desc : grad_op_descs) { for (auto& desc : grad_op_descs) {
for (const std::string& in_name : desc->InputArgumentNames()) { for (const std::string& in_name : desc->InputArgumentNames()) {
if (no_grad_vars->count(in_name)) { if (no_grad_vars->count(in_name)) {
@ -394,8 +393,8 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
0, in_name.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1); 0, in_name.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1);
std::string new_name = prefix + kZeroVarSuffix; std::string new_name = prefix + kZeroVarSuffix;
desc->Rename(in_name, new_name); desc->Rename(in_name, new_name);
std::unique_ptr<OpDescBind> fill_zeros_op( std::unique_ptr<OpDesc> fill_zeros_op(
new OpDescBind("fill_zeros_like", {{"X", {prefix}}}, new OpDesc("fill_zeros_like", {{"X", {prefix}}},
{{"Y", {new_name}}}, AttributeMap{})); {{"Y", {new_name}}}, AttributeMap{}));
pending_fill_zeros_ops.push_back(std::move(fill_zeros_op)); pending_fill_zeros_ops.push_back(std::move(fill_zeros_op));
} }
@ -408,34 +407,33 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
return grad_op_descs; return grad_op_descs;
} }
static BlockDescBind* CreateStepBlock( static BlockDesc* CreateStepBlock(
ProgramDescBind& program_desc, ProgramDesc& program_desc, std::unordered_set<std::string>* no_grad_vars,
std::unordered_set<std::string>* no_grad_vars,
std::unordered_map<std::string, std::string>* grad_to_var, std::unordered_map<std::string, std::string>* grad_to_var,
int step_block_idx); int step_block_idx);
std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward( std::vector<std::unique_ptr<OpDesc>> MakeBlockBackward(
ProgramDescBind& program_desc, int block_idx, ProgramDesc& program_desc, int block_idx,
std::unordered_set<std::string>* no_grad_vars, std::unordered_set<std::string>* no_grad_vars,
std::unordered_map<std::string, std::string>* grad_to_var) { std::unordered_map<std::string, std::string>* grad_to_var) {
VLOG(5) << "MakeBlockBackward"; VLOG(5) << "MakeBlockBackward";
BlockDescBind* cur_block = program_desc.MutableBlock(block_idx); BlockDesc* cur_block = program_desc.MutableBlock(block_idx);
std::vector<OpDescBind*> op_descs = cur_block->AllOps(); std::vector<OpDesc*> op_descs = cur_block->AllOps();
std::unordered_map<std::string, std::vector<size_t>> dup_out_ops; std::unordered_map<std::string, std::vector<size_t>> dup_out_ops;
size_t grad_desc_idx = 0; size_t grad_desc_idx = 0;
std::vector<std::unique_ptr<OpDescBind>> backward_descs; std::vector<std::unique_ptr<OpDesc>> backward_descs;
for (auto it = op_descs.rbegin(); it != op_descs.rend(); ++it) { for (auto it = op_descs.rbegin(); it != op_descs.rend(); ++it) {
VLOG(5) << "Making backward " << (*it)->Type() << " op"; VLOG(5) << "Making backward " << (*it)->Type() << " op";
std::vector<std::unique_ptr<OpDescBind>> op_grads; std::vector<std::unique_ptr<OpDesc>> op_grads;
if ((*it)->Type() == "recurrent" || (*it)->Type() == "while") { if ((*it)->Type() == "recurrent" || (*it)->Type() == "while") {
int step_block_idx = (*it)->GetBlockAttr("sub_block"); int step_block_idx = (*it)->GetBlockAttr("sub_block");
BlockDescBind* backward_block = CreateStepBlock( BlockDesc* backward_block = CreateStepBlock(program_desc, no_grad_vars,
program_desc, no_grad_vars, grad_to_var, step_block_idx); grad_to_var, step_block_idx);
op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var, {backward_block}); op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var, {backward_block});
} else if ((*it)->Type() == "conditional_block") { } else if ((*it)->Type() == "conditional_block") {
BlockDescBind* backward_block = BlockDesc* backward_block =
CreateStepBlock(program_desc, no_grad_vars, grad_to_var, CreateStepBlock(program_desc, no_grad_vars, grad_to_var,
(*it)->GetBlockAttr("sub_block")); (*it)->GetBlockAttr("sub_block"));
op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var, {backward_block}); op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var, {backward_block});
@ -463,14 +461,14 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
} }
++grad_desc_idx; ++grad_desc_idx;
} }
std::transform( std::transform(op_grads.begin(), op_grads.end(),
op_grads.begin(), op_grads.end(), std::back_inserter(backward_descs), std::back_inserter(backward_descs),
[](std::unique_ptr<OpDescBind>& ptr) { return std::move(ptr); }); [](std::unique_ptr<OpDesc>& ptr) { return std::move(ptr); });
} }
VLOG(5) << "Appending Sums"; VLOG(5) << "Appending Sums";
// Check whether some variables are written more than once // Check whether some variables are written more than once
std::list<std::pair<size_t, std::unique_ptr<OpDescBind>>> pending_sum_ops; std::list<std::pair<size_t, std::unique_ptr<OpDesc>>> pending_sum_ops;
for (const auto& dup : dup_out_ops) { for (const auto& dup : dup_out_ops) {
const std::string& out_name = dup.first; const std::string& out_name = dup.first;
const std::vector<size_t> dup_op = dup.second; const std::vector<size_t> dup_op = dup.second;
@ -486,16 +484,15 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
sum_op_inputs.emplace_back(new_name); sum_op_inputs.emplace_back(new_name);
next_g_name = sum_op_inputs.back(); next_g_name = sum_op_inputs.back();
} }
std::unique_ptr<OpDescBind> sum_op( std::unique_ptr<OpDesc> sum_op(new OpDesc("sum", {{"X", sum_op_inputs}},
new OpDescBind("sum", {{"X", sum_op_inputs}}, {{"Out", {out_name}}}, {{"Out", {out_name}}},
AttributeMap{})); AttributeMap{}));
pending_sum_ops.push_back({dup_op.back(), std::move(sum_op)}); pending_sum_ops.push_back({dup_op.back(), std::move(sum_op)});
} }
} }
pending_sum_ops.sort( pending_sum_ops.sort([](const std::pair<size_t, std::unique_ptr<OpDesc>>& a,
[](const std::pair<size_t, std::unique_ptr<OpDescBind>>& a, const std::pair<size_t, std::unique_ptr<OpDesc>>& b) {
const std::pair<size_t, std::unique_ptr<OpDescBind>>& b) {
return a.first > b.first; return a.first > b.first;
}); });
for (auto& p : pending_sum_ops) { for (auto& p : pending_sum_ops) {
@ -508,14 +505,13 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
return backward_descs; return backward_descs;
} }
static BlockDescBind* CreateStepBlock( static BlockDesc* CreateStepBlock(
ProgramDescBind& program_desc, ProgramDesc& program_desc, std::unordered_set<std::string>* no_grad_vars,
std::unordered_set<std::string>* no_grad_vars,
std::unordered_map<std::string, std::string>* grad_to_var, std::unordered_map<std::string, std::string>* grad_to_var,
int step_block_idx) { int step_block_idx) {
auto backward_block_op_descs = MakeBlockBackward(program_desc, step_block_idx, auto backward_block_op_descs = MakeBlockBackward(program_desc, step_block_idx,
no_grad_vars, grad_to_var); no_grad_vars, grad_to_var);
BlockDescBind* backward_block = BlockDesc* backward_block =
program_desc.AppendBlock(*program_desc.MutableBlock(step_block_idx)); program_desc.AppendBlock(*program_desc.MutableBlock(step_block_idx));
for (auto& ptr : backward_block_op_descs) { for (auto& ptr : backward_block_op_descs) {
backward_block->AppendAllocatedOp(move(ptr)); backward_block->AppendAllocatedOp(move(ptr));
@ -524,7 +520,7 @@ static BlockDescBind* CreateStepBlock(
} }
ParamGradInfoMap AppendBackward( ParamGradInfoMap AppendBackward(
ProgramDescBind& program_desc, const VarDescBind& target, ProgramDesc& program_desc, const VarDesc& target,
const std::unordered_set<std::string>& no_grad_vars) { const std::unordered_set<std::string>& no_grad_vars) {
std::unordered_set<std::string> no_grad_var_names; std::unordered_set<std::string> no_grad_var_names;
no_grad_var_names.reserve(no_grad_vars.size() + 1); no_grad_var_names.reserve(no_grad_vars.size() + 1);
@ -541,8 +537,8 @@ ParamGradInfoMap AppendBackward(
PADDLE_ENFORCE(is_scalar, "target should be scalar"); PADDLE_ENFORCE(is_scalar, "target should be scalar");
VLOG(3) << "backward from loss=" << target.Name() VLOG(3) << "backward from loss=" << target.Name()
<< " data_type=" << target.GetDataType(); << " data_type=" << target.GetDataType();
std::unique_ptr<OpDescBind> fill_one_op( std::unique_ptr<OpDesc> fill_one_op(
new OpDescBind("fill_constant", {}, {{"Out", {fill_one_op_out}}}, new OpDesc("fill_constant", {}, {{"Out", {fill_one_op_out}}},
{{"shape", std::vector<int>{1}}, {{"shape", std::vector<int>{1}},
{"value", static_cast<float>(1.0)}, {"value", static_cast<float>(1.0)},
{"dtype", target.GetDataType()}})); {"dtype", target.GetDataType()}}));

@ -49,7 +49,7 @@ using ParamGradInfoMap = std::unordered_map<std::string /*fwd_var_name*/,
GradVarInfo /*grad_var_info*/>; GradVarInfo /*grad_var_info*/>;
ParamGradInfoMap AppendBackward( ParamGradInfoMap AppendBackward(
ProgramDescBind& program_desc, const VarDescBind& target, ProgramDesc& program_desc, const VarDesc& target,
const std::unordered_set<std::string>& no_grad_vars); const std::unordered_set<std::string>& no_grad_vars);
} // namespace framework } // namespace framework

File diff suppressed because it is too large Load Diff

@ -19,18 +19,18 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
VarDescBind *BlockDescBind::Var(const std::string &name) { VarDesc *BlockDesc::Var(const std::string &name) {
auto it = vars_.find(name); auto it = vars_.find(name);
if (it != vars_.end()) { if (it != vars_.end()) {
return it->second.get(); return it->second.get();
} }
need_update_ = true; need_update_ = true;
auto *var = new VarDescBind(name); auto *var = new VarDesc(name);
vars_[name].reset(var); vars_[name].reset(var);
return var; return var;
} }
VarDescBind *BlockDescBind::FindVar(const std::string &name) const { VarDesc *BlockDesc::FindVar(const std::string &name) const {
auto it = vars_.find(name); auto it = vars_.find(name);
if (it == vars_.end()) { if (it == vars_.end()) {
return nullptr; return nullptr;
@ -38,11 +38,11 @@ VarDescBind *BlockDescBind::FindVar(const std::string &name) const {
return it->second.get(); return it->second.get();
} }
bool BlockDescBind::HasVar(const std::string &name) const { bool BlockDesc::HasVar(const std::string &name) const {
return vars_.find(name) != vars_.end(); return vars_.find(name) != vars_.end();
} }
VarDescBind *BlockDescBind::FindVarRecursive(const std::string &name) const { VarDesc *BlockDesc::FindVarRecursive(const std::string &name) const {
if (name == kEmptyVarName) return nullptr; if (name == kEmptyVarName) return nullptr;
auto it = vars_.find(name); auto it = vars_.find(name);
@ -53,53 +53,52 @@ VarDescBind *BlockDescBind::FindVarRecursive(const std::string &name) const {
return it->second.get(); return it->second.get();
} }
VarDescBind *BlockDescBind::FindRecursiveOrCreateVar( VarDesc *BlockDesc::FindRecursiveOrCreateVar(const std::string &name_bytes) {
const std::string &name_bytes) { VarDesc *res = FindVarRecursive(name_bytes);
VarDescBind *res = FindVarRecursive(name_bytes);
if (res == nullptr) { if (res == nullptr) {
res = Var(name_bytes); res = Var(name_bytes);
} }
return res; return res;
} }
bool BlockDescBind::HasVarRecursive(const std::string &name) const { bool BlockDesc::HasVarRecursive(const std::string &name) const {
return FindVarRecursive(name) != nullptr; return FindVarRecursive(name) != nullptr;
} }
std::vector<VarDescBind *> BlockDescBind::AllVars() const { std::vector<VarDesc *> BlockDesc::AllVars() const {
std::vector<VarDescBind *> res; std::vector<VarDesc *> res;
for (const auto &p : vars_) { for (const auto &p : vars_) {
res.push_back(p.second.get()); res.push_back(p.second.get());
} }
return res; return res;
} }
OpDescBind *BlockDescBind::AppendOp() { OpDesc *BlockDesc::AppendOp() {
need_update_ = true; need_update_ = true;
ops_.emplace_back(new OpDescBind()); ops_.emplace_back(new OpDesc());
return ops_.back().get(); return ops_.back().get();
} }
void BlockDescBind::AppendAllocatedOp(std::unique_ptr<OpDescBind> &&op_desc) { void BlockDesc::AppendAllocatedOp(std::unique_ptr<OpDesc> &&op_desc) {
need_update_ = true; need_update_ = true;
ops_.emplace_back(std::move(op_desc)); ops_.emplace_back(std::move(op_desc));
} }
OpDescBind *BlockDescBind::PrependOp() { OpDesc *BlockDesc::PrependOp() {
need_update_ = true; need_update_ = true;
ops_.emplace_front(new OpDescBind()); ops_.emplace_front(new OpDesc());
return ops_.front().get(); return ops_.front().get();
} }
std::vector<OpDescBind *> BlockDescBind::AllOps() const { std::vector<OpDesc *> BlockDesc::AllOps() const {
std::vector<OpDescBind *> res; std::vector<OpDesc *> res;
for (const auto &op : ops_) { for (const auto &op : ops_) {
res.push_back(op.get()); res.push_back(op.get());
} }
return res; return res;
} }
void BlockDescBind::Flush() { void BlockDesc::Flush() {
for (auto &op_desc : ops_) { for (auto &op_desc : ops_) {
op_desc->Flush(); op_desc->Flush();
} }
@ -121,43 +120,43 @@ void BlockDescBind::Flush() {
} }
} }
BlockDescBind *BlockDescBind::ParentBlock() const { BlockDesc *BlockDesc::ParentBlock() const {
if (this->desc_->parent_idx() == kNoneBlockIndex) { if (this->desc_->parent_idx() == kNoneBlockIndex) {
return nullptr; return nullptr;
} }
return prog_->MutableBlock(static_cast<size_t>(this->desc_->parent_idx())); return prog_->MutableBlock(static_cast<size_t>(this->desc_->parent_idx()));
} }
proto::BlockDesc *BlockDescBind::Proto() { proto::BlockDesc *BlockDesc::Proto() {
Flush(); Flush();
return desc_; return desc_;
} }
BlockDescBind::BlockDescBind(ProgramDescBind *prog, proto::BlockDesc *desc) BlockDesc::BlockDesc(ProgramDesc *prog, proto::BlockDesc *desc)
: prog_(prog), desc_(desc), need_update_(false) { : prog_(prog), desc_(desc), need_update_(false) {
for (const proto::VarDesc &var_desc : desc_->vars()) { for (const proto::VarDesc &var_desc : desc_->vars()) {
vars_[var_desc.name()].reset(new VarDescBind(var_desc)); vars_[var_desc.name()].reset(new VarDesc(var_desc));
} }
for (const proto::OpDesc &op_desc : desc_->ops()) { for (const proto::OpDesc &op_desc : desc_->ops()) {
ops_.emplace_back(new OpDescBind(op_desc, prog)); ops_.emplace_back(new OpDesc(op_desc, prog));
} }
} }
BlockDescBind::BlockDescBind(const BlockDescBind &other, proto::BlockDesc *desc, BlockDesc::BlockDesc(const BlockDesc &other, proto::BlockDesc *desc,
ProgramDescBind *prog) ProgramDesc *prog)
: prog_(prog), desc_(desc) { : prog_(prog), desc_(desc) {
need_update_ = true; need_update_ = true;
for (auto &op : other.ops_) { for (auto &op : other.ops_) {
ops_.emplace_back(new OpDescBind(*op)); ops_.emplace_back(new OpDesc(*op));
} }
for (auto &it : other.vars_) { for (auto &it : other.vars_) {
auto *var = new VarDescBind(*it.second); auto *var = new VarDesc(*it.second);
vars_[it.first].reset(var); vars_[it.first].reset(var);
} }
} }
void BlockDescBind::ClearPBOps() { void BlockDesc::ClearPBOps() {
auto ops = this->desc_->mutable_ops(); auto ops = this->desc_->mutable_ops();
while (!ops->empty()) { while (!ops->empty()) {
// we do not own the OpDesc, so release the ownership. // we do not own the OpDesc, so release the ownership.
@ -165,7 +164,7 @@ void BlockDescBind::ClearPBOps() {
} }
} }
void BlockDescBind::ClearPBVars() { void BlockDesc::ClearPBVars() {
auto vars = this->desc_->mutable_vars(); auto vars = this->desc_->mutable_vars();
while (!vars->empty()) { while (!vars->empty()) {
// we do not own the VarDesc, so release the ownership. // we do not own the VarDesc, so release the ownership.

@ -28,20 +28,19 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class ProgramDescBind; class ProgramDesc;
// Each Protobuf Message, we provide a XXXBind class. In that class, we optimize // Each Protobuf Message, we provide a XXXBind class. In that class, we optimize
// read/write speed. Only when we want the protobuf message, the local changes // read/write speed. Only when we want the protobuf message, the local changes
// will be synchronized (by `Sync` method). // will be synchronized (by `Sync` method).
class BlockDescBind { class BlockDesc {
public: public:
BlockDescBind(ProgramDescBind *prog, proto::BlockDesc *desc); BlockDesc(ProgramDesc *prog, proto::BlockDesc *desc);
BlockDescBind(const BlockDescBind &other, proto::BlockDesc *desc, BlockDesc(const BlockDesc &other, proto::BlockDesc *desc, ProgramDesc *prog);
ProgramDescBind *prog);
~BlockDescBind() { ~BlockDesc() {
this->ClearPBVars(); this->ClearPBVars();
this->ClearPBOps(); this->ClearPBOps();
} }
@ -50,15 +49,15 @@ class BlockDescBind {
int32_t Parent() const { return desc_->parent_idx(); } int32_t Parent() const { return desc_->parent_idx(); }
VarDescBind *Var(const std::string &name_bytes); VarDesc *Var(const std::string &name_bytes);
VarDescBind *FindVar(const std::string &name_bytes) const; VarDesc *FindVar(const std::string &name_bytes) const;
bool HasVar(const std::string &var_name) const; bool HasVar(const std::string &var_name) const;
VarDescBind *FindVarRecursive(const std::string &name_bytes) const; VarDesc *FindVarRecursive(const std::string &name_bytes) const;
VarDescBind *FindRecursiveOrCreateVar(const std::string &name_bytes); VarDesc *FindRecursiveOrCreateVar(const std::string &name_bytes);
bool HasVarRecursive(const std::string &var_name) const; bool HasVarRecursive(const std::string &var_name) const;
@ -70,41 +69,41 @@ class BlockDescBind {
return var_names; return var_names;
} }
std::vector<VarDescBind *> AllVars() const; std::vector<VarDesc *> AllVars() const;
BlockDescBind *ParentBlock() const; BlockDesc *ParentBlock() const;
OpDescBind *AppendOp(); OpDesc *AppendOp();
void AppendAllocatedOp(std::unique_ptr<OpDescBind> &&op_desc); void AppendAllocatedOp(std::unique_ptr<OpDesc> &&op_desc);
OpDescBind *PrependOp(); OpDesc *PrependOp();
std::vector<OpDescBind *> AllOps() const; std::vector<OpDesc *> AllOps() const;
size_t OpSize() const { return ops_.size(); } size_t OpSize() const { return ops_.size(); }
OpDescBind *Op(int idx) { return ops_.at(idx).get(); } OpDesc *Op(int idx) { return ops_.at(idx).get(); }
void Flush(); void Flush();
proto::BlockDesc *Proto(); proto::BlockDesc *Proto();
ProgramDescBind *Program() { return this->prog_; } ProgramDesc *Program() { return this->prog_; }
private: private:
void ClearPBOps(); void ClearPBOps();
void ClearPBVars(); void ClearPBVars();
private: private:
ProgramDescBind *prog_; // not_own ProgramDesc *prog_; // not_own
proto::BlockDesc *desc_; // not_own proto::BlockDesc *desc_; // not_own
bool need_update_; bool need_update_;
std::deque<std::unique_ptr<OpDescBind>> ops_; std::deque<std::unique_ptr<OpDesc>> ops_;
std::unordered_map<std::string, std::unique_ptr<VarDescBind>> vars_; std::unordered_map<std::string, std::unique_ptr<VarDesc>> vars_;
DISABLE_COPY_AND_ASSIGN(BlockDescBind); DISABLE_COPY_AND_ASSIGN(BlockDesc);
}; };
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle

@ -106,10 +106,10 @@ template <typename T>
struct OpInfoFiller<T, kGradOpDescMaker> { struct OpInfoFiller<T, kGradOpDescMaker> {
void operator()(const char* op_type, OpInfo* info) const { void operator()(const char* op_type, OpInfo* info) const {
info->grad_op_maker_ = []( info->grad_op_maker_ = [](
const OpDescBind& fwd_op, const OpDesc& fwd_op,
const std::unordered_set<std::string>& no_grad_set, const std::unordered_set<std::string>& no_grad_set,
std::unordered_map<std::string, std::string>* grad_to_var, std::unordered_map<std::string, std::string>* grad_to_var,
const std::vector<BlockDescBind*>& grad_block) { const std::vector<BlockDesc*>& grad_block) {
T maker(fwd_op, no_grad_set, grad_to_var, grad_block); T maker(fwd_op, no_grad_set, grad_to_var, grad_block);
return maker(); return maker();
}; };
@ -119,7 +119,7 @@ struct OpInfoFiller<T, kGradOpDescMaker> {
template <typename T> template <typename T>
struct OpInfoFiller<T, kVarTypeInference> { struct OpInfoFiller<T, kVarTypeInference> {
void operator()(const char* op_type, OpInfo* info) const { void operator()(const char* op_type, OpInfo* info) const {
info->infer_var_type_ = [](const OpDescBind& fwd_op, BlockDescBind* block) { info->infer_var_type_ = [](const OpDesc& fwd_op, BlockDesc* block) {
T inference; T inference;
inference(fwd_op, block); inference(fwd_op, block);
}; };

@ -64,7 +64,7 @@ static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
} }
} }
void Executor::Run(const ProgramDescBind& pdesc, Scope* scope, int block_id, void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
bool create_local_scope) { bool create_local_scope) {
// TODO(tonyyang-svail): // TODO(tonyyang-svail):
// - only runs on the first device (i.e. no interdevice communication) // - only runs on the first device (i.e. no interdevice communication)

@ -114,7 +114,7 @@ class Executor {
* ProgramDesc * ProgramDesc
* Scope * Scope
*/ */
void Run(const ProgramDescBind&, Scope*, int, bool create_local_scope = true); void Run(const ProgramDesc&, Scope*, int, bool create_local_scope = true);
private: private:
std::vector<const platform::DeviceContext*> device_contexts_; std::vector<const platform::DeviceContext*> device_contexts_;

@ -25,18 +25,16 @@ namespace framework {
class GradOpDescMakerBase { class GradOpDescMakerBase {
public: public:
explicit GradOpDescMakerBase( explicit GradOpDescMakerBase(
const OpDescBind& fwd_op, const OpDesc& fwd_op, const std::unordered_set<std::string>& no_grad_set,
const std::unordered_set<std::string>& no_grad_set,
std::unordered_map<std::string, std::string>* grad_to_var, std::unordered_map<std::string, std::string>* grad_to_var,
const std::vector<BlockDescBind*>& grad_block = const std::vector<BlockDesc*>& grad_block = std::vector<BlockDesc*>())
std::vector<BlockDescBind*>())
: fwd_op_(fwd_op), : fwd_op_(fwd_op),
no_grad_set_(no_grad_set), no_grad_set_(no_grad_set),
grad_to_var_(grad_to_var), grad_to_var_(grad_to_var),
grad_block_(grad_block) {} grad_block_(grad_block) {}
virtual ~GradOpDescMakerBase() = default; virtual ~GradOpDescMakerBase() = default;
virtual std::vector<std::unique_ptr<OpDescBind>> operator()() const = 0; virtual std::vector<std::unique_ptr<OpDesc>> operator()() const = 0;
protected: protected:
std::vector<std::string> InputGrad(const std::string& name, std::vector<std::string> InputGrad(const std::string& name,
@ -105,26 +103,26 @@ class GradOpDescMakerBase {
std::string ForwardOpType() const { return this->fwd_op_.Type(); } std::string ForwardOpType() const { return this->fwd_op_.Type(); }
private: private:
const OpDescBind& fwd_op_; const OpDesc& fwd_op_;
const std::unordered_set<std::string>& no_grad_set_; const std::unordered_set<std::string>& no_grad_set_;
std::unordered_map<std::string, std::string>* grad_to_var_; std::unordered_map<std::string, std::string>* grad_to_var_;
protected: protected:
std::vector<BlockDescBind*> grad_block_; std::vector<BlockDesc*> grad_block_;
}; };
class SingleGradOpDescMaker : public GradOpDescMakerBase { class SingleGradOpDescMaker : public GradOpDescMakerBase {
public: public:
using GradOpDescMakerBase::GradOpDescMakerBase; using GradOpDescMakerBase::GradOpDescMakerBase;
std::vector<std::unique_ptr<OpDescBind>> operator()() const { std::vector<std::unique_ptr<OpDesc>> operator()() const {
std::vector<std::unique_ptr<OpDescBind>> retv; std::vector<std::unique_ptr<OpDesc>> retv;
retv.emplace_back(this->Apply()); retv.emplace_back(this->Apply());
return retv; return retv;
} }
protected: protected:
virtual std::unique_ptr<OpDescBind> Apply() const = 0; virtual std::unique_ptr<OpDesc> Apply() const = 0;
}; };
template <bool DropEmptyIG = true> template <bool DropEmptyIG = true>
@ -133,8 +131,8 @@ class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
using SingleGradOpDescMaker::SingleGradOpDescMaker; using SingleGradOpDescMaker::SingleGradOpDescMaker;
protected: protected:
virtual std::unique_ptr<OpDescBind> Apply() const { virtual std::unique_ptr<OpDesc> Apply() const {
auto* grad = new OpDescBind(); auto* grad = new OpDesc();
grad->SetType(this->GradOpType()); grad->SetType(this->GradOpType());
for (auto& input_param : this->InputNames()) { for (auto& input_param : this->InputNames()) {
@ -150,7 +148,7 @@ class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
grad->SetAttrMap(this->Attrs()); grad->SetAttrMap(this->Attrs());
return std::unique_ptr<OpDescBind>(grad); return std::unique_ptr<OpDesc>(grad);
} }
virtual std::string GradOpType() const { virtual std::string GradOpType() const {
@ -161,7 +159,7 @@ class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
class EmptyGradOpMaker : public GradOpDescMakerBase { class EmptyGradOpMaker : public GradOpDescMakerBase {
public: public:
using GradOpDescMakerBase::GradOpDescMakerBase; using GradOpDescMakerBase::GradOpDescMakerBase;
std::vector<std::unique_ptr<OpDescBind>> operator()() const override { std::vector<std::unique_ptr<OpDesc>> operator()() const override {
return {}; return {};
} }
}; };

@ -25,12 +25,11 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class OpDescBind; class OpDesc;
class BlockDescBind; class BlockDesc;
class CompileTimeInferShapeContext : public InferShapeContext { class CompileTimeInferShapeContext : public InferShapeContext {
public: public:
CompileTimeInferShapeContext(const OpDescBind &op, CompileTimeInferShapeContext(const OpDesc &op, const BlockDesc &block);
const BlockDescBind &block);
bool HasInput(const std::string &name) const override; bool HasInput(const std::string &name) const override;
@ -76,13 +75,12 @@ class CompileTimeInferShapeContext : public InferShapeContext {
void SetDim(const std::string &name, const DDim &dim) override; void SetDim(const std::string &name, const DDim &dim) override;
const OpDescBind &op_; const OpDesc &op_;
const BlockDescBind &block_; const BlockDesc &block_;
}; };
OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs, OpDesc::OpDesc(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const VariableNameMap &outputs, const AttributeMap &attrs) {
const AttributeMap &attrs) {
desc_.set_type(type); desc_.set_type(type);
inputs_ = inputs; inputs_ = inputs;
outputs_ = outputs; outputs_ = outputs;
@ -90,7 +88,7 @@ OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs,
need_update_ = true; need_update_ = true;
} }
OpDescBind::OpDescBind(const proto::OpDesc &desc, ProgramDescBind *prog) OpDesc::OpDesc(const proto::OpDesc &desc, ProgramDesc *prog)
: desc_(desc), need_update_(false) { : desc_(desc), need_update_(false) {
// restore inputs_ // restore inputs_
int input_size = desc_.inputs_size(); int input_size = desc_.inputs_size();
@ -126,20 +124,19 @@ OpDescBind::OpDescBind(const proto::OpDesc &desc, ProgramDescBind *prog)
} }
} }
proto::OpDesc *OpDescBind::Proto() { proto::OpDesc *OpDesc::Proto() {
Flush(); Flush();
return &desc_; return &desc_;
} }
const std::vector<std::string> &OpDescBind::Input( const std::vector<std::string> &OpDesc::Input(const std::string &name) const {
const std::string &name) const {
auto it = inputs_.find(name); auto it = inputs_.find(name);
PADDLE_ENFORCE(it != inputs_.end(), "Input %s cannot be found in Op %s", name, PADDLE_ENFORCE(it != inputs_.end(), "Input %s cannot be found in Op %s", name,
Type()); Type());
return it->second; return it->second;
} }
std::vector<std::string> OpDescBind::InputArgumentNames() const { std::vector<std::string> OpDesc::InputArgumentNames() const {
std::vector<std::string> retv; std::vector<std::string> retv;
for (auto &ipt : this->inputs_) { for (auto &ipt : this->inputs_) {
retv.insert(retv.end(), ipt.second.begin(), ipt.second.end()); retv.insert(retv.end(), ipt.second.begin(), ipt.second.end());
@ -147,21 +144,20 @@ std::vector<std::string> OpDescBind::InputArgumentNames() const {
return retv; return retv;
} }
void OpDescBind::SetInput(const std::string &param_name, void OpDesc::SetInput(const std::string &param_name,
const std::vector<std::string> &args) { const std::vector<std::string> &args) {
need_update_ = true; need_update_ = true;
inputs_[param_name] = args; inputs_[param_name] = args;
} }
const std::vector<std::string> &OpDescBind::Output( const std::vector<std::string> &OpDesc::Output(const std::string &name) const {
const std::string &name) const {
auto it = outputs_.find(name); auto it = outputs_.find(name);
PADDLE_ENFORCE(it != outputs_.end(), "Output %s cannot be found in Op %s", PADDLE_ENFORCE(it != outputs_.end(), "Output %s cannot be found in Op %s",
name, Type()); name, Type());
return it->second; return it->second;
} }
std::vector<std::string> OpDescBind::OutputArgumentNames() const { std::vector<std::string> OpDesc::OutputArgumentNames() const {
std::vector<std::string> retv; std::vector<std::string> retv;
for (auto &ipt : this->outputs_) { for (auto &ipt : this->outputs_) {
retv.insert(retv.end(), ipt.second.begin(), ipt.second.end()); retv.insert(retv.end(), ipt.second.begin(), ipt.second.end());
@ -169,19 +165,19 @@ std::vector<std::string> OpDescBind::OutputArgumentNames() const {
return retv; return retv;
} }
void OpDescBind::SetOutput(const std::string &param_name, void OpDesc::SetOutput(const std::string &param_name,
const std::vector<std::string> &args) { const std::vector<std::string> &args) {
need_update_ = true; need_update_ = true;
this->outputs_[param_name] = args; this->outputs_[param_name] = args;
} }
proto::AttrType OpDescBind::GetAttrType(const std::string &name) const { proto::AttrType OpDesc::GetAttrType(const std::string &name) const {
auto it = attrs_.find(name); auto it = attrs_.find(name);
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name); PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
return static_cast<proto::AttrType>(it->second.which() - 1); return static_cast<proto::AttrType>(it->second.which() - 1);
} }
std::vector<std::string> OpDescBind::AttrNames() const { std::vector<std::string> OpDesc::AttrNames() const {
std::vector<std::string> retv; std::vector<std::string> retv;
retv.reserve(attrs_.size()); retv.reserve(attrs_.size());
for (auto &attr : attrs_) { for (auto &attr : attrs_) {
@ -190,41 +186,39 @@ std::vector<std::string> OpDescBind::AttrNames() const {
return retv; return retv;
} }
void OpDescBind::SetAttr(const std::string &name, const Attribute &v) { void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
this->attrs_[name] = v; this->attrs_[name] = v;
need_update_ = true; need_update_ = true;
} }
void OpDescBind::SetBlockAttr(const std::string &name, BlockDescBind &block) { void OpDesc::SetBlockAttr(const std::string &name, BlockDesc &block) {
this->attrs_[name] = &block; this->attrs_[name] = &block;
need_update_ = true; need_update_ = true;
} }
void OpDescBind::SetAttrMap( void OpDesc::SetAttrMap(
const std::unordered_map<std::string, Attribute> &attr_map) { const std::unordered_map<std::string, Attribute> &attr_map) {
attrs_ = attr_map; attrs_ = attr_map;
need_update_ = true; need_update_ = true;
} }
Attribute OpDescBind::GetAttr(const std::string &name) const { Attribute OpDesc::GetAttr(const std::string &name) const {
auto it = attrs_.find(name); auto it = attrs_.find(name);
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name); PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
return it->second; return it->second;
} }
int OpDescBind::GetBlockAttr(const std::string &name) const { int OpDesc::GetBlockAttr(const std::string &name) const {
auto it = attrs_.find(name); auto it = attrs_.find(name);
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name); PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
return boost::get<BlockDescBind *>(it->second)->ID(); return boost::get<BlockDesc *>(it->second)->ID();
} }
const std::unordered_map<std::string, Attribute> &OpDescBind::GetAttrMap() const std::unordered_map<std::string, Attribute> &OpDesc::GetAttrMap() const {
const {
return attrs_; return attrs_;
} }
void OpDescBind::Rename(const std::string &old_name, void OpDesc::Rename(const std::string &old_name, const std::string &new_name) {
const std::string &new_name) {
for (auto &input : inputs_) { for (auto &input : inputs_) {
std::replace(input.second.begin(), input.second.end(), old_name, new_name); std::replace(input.second.begin(), input.second.end(), old_name, new_name);
} }
@ -235,7 +229,7 @@ void OpDescBind::Rename(const std::string &old_name,
need_update_ = true; need_update_ = true;
} }
void OpDescBind::RenameOutput(const std::string &old_name, void OpDesc::RenameOutput(const std::string &old_name,
const std::string &new_name) { const std::string &new_name) {
for (auto &output : outputs_) { for (auto &output : outputs_) {
std::replace(output.second.begin(), output.second.end(), old_name, std::replace(output.second.begin(), output.second.end(), old_name,
@ -244,7 +238,7 @@ void OpDescBind::RenameOutput(const std::string &old_name,
need_update_ = true; need_update_ = true;
} }
void OpDescBind::RenameInput(const std::string &old_name, void OpDesc::RenameInput(const std::string &old_name,
const std::string &new_name) { const std::string &new_name) {
for (auto &input : inputs_) { for (auto &input : inputs_) {
std::replace(input.second.begin(), input.second.end(), old_name, new_name); std::replace(input.second.begin(), input.second.end(), old_name, new_name);
@ -278,7 +272,7 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> {
void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); } void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); }
}; };
void OpDescBind::Flush() { void OpDesc::Flush() {
if (need_update_) { if (need_update_) {
this->desc_.mutable_inputs()->Clear(); this->desc_.mutable_inputs()->Clear();
for (auto &ipt : inputs_) { for (auto &ipt : inputs_) {
@ -330,7 +324,7 @@ static void InitInferShapeFuncs() {
}); });
} }
void OpDescBind::CheckAttrs() { void OpDesc::CheckAttrs() {
PADDLE_ENFORCE(!Type().empty(), PADDLE_ENFORCE(!Type().empty(),
"CheckAttr() can not be called before type is setted."); "CheckAttr() can not be called before type is setted.");
auto *checker = OpInfoMap::Instance().Get(Type()).Checker(); auto *checker = OpInfoMap::Instance().Get(Type()).Checker();
@ -342,7 +336,7 @@ void OpDescBind::CheckAttrs() {
checker->Check(attrs_); checker->Check(attrs_);
} }
void OpDescBind::InferShape(const BlockDescBind &block) const { void OpDesc::InferShape(const BlockDesc &block) const {
VLOG(3) << "CompileTime infer shape on " << Type(); VLOG(3) << "CompileTime infer shape on " << Type();
InitInferShapeFuncs(); InitInferShapeFuncs();
auto &infer_shape = OpInfoMap::Instance().Get(this->Type()).infer_shape_; auto &infer_shape = OpInfoMap::Instance().Get(this->Type()).infer_shape_;
@ -365,7 +359,7 @@ void OpDescBind::InferShape(const BlockDescBind &block) const {
infer_shape(&ctx); infer_shape(&ctx);
} }
void OpDescBind::InferVarType(BlockDescBind *block) const { void OpDesc::InferVarType(BlockDesc *block) const {
auto &info = OpInfoMap::Instance().Get(this->Type()); auto &info = OpInfoMap::Instance().Get(this->Type());
if (info.infer_var_type_) { if (info.infer_var_type_) {
info.infer_var_type_(*this, block); info.infer_var_type_(*this, block);
@ -384,7 +378,7 @@ void OpDescBind::InferVarType(BlockDescBind *block) const {
} }
CompileTimeInferShapeContext::CompileTimeInferShapeContext( CompileTimeInferShapeContext::CompileTimeInferShapeContext(
const OpDescBind &op, const BlockDescBind &block) const OpDesc &op, const BlockDesc &block)
: op_(op), block_(block) {} : op_(op), block_(block) {}
bool CompileTimeInferShapeContext::HasInput(const std::string &name) const { bool CompileTimeInferShapeContext::HasInput(const std::string &name) const {

@ -23,17 +23,17 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class BlockDescBind; class BlockDesc;
class ProgramDescBind; class ProgramDesc;
class OpDescBind { class OpDesc {
public: public:
OpDescBind() {} OpDesc() {}
OpDescBind(const std::string &type, const VariableNameMap &inputs, OpDesc(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs); const VariableNameMap &outputs, const AttributeMap &attrs);
OpDescBind(const proto::OpDesc &desc, ProgramDescBind *prog); OpDesc(const proto::OpDesc &desc, ProgramDesc *prog);
proto::OpDesc *Proto(); proto::OpDesc *Proto();
@ -65,7 +65,7 @@ class OpDescBind {
void SetAttr(const std::string &name, const Attribute &v); void SetAttr(const std::string &name, const Attribute &v);
void SetBlockAttr(const std::string &name, BlockDescBind &block); void SetBlockAttr(const std::string &name, BlockDesc &block);
Attribute GetAttr(const std::string &name) const; Attribute GetAttr(const std::string &name) const;
@ -107,9 +107,9 @@ class OpDescBind {
void CheckAttrs(); void CheckAttrs();
void InferShape(const BlockDescBind &block) const; void InferShape(const BlockDesc &block) const;
void InferVarType(BlockDescBind *block) const; void InferVarType(BlockDesc *block) const;
void MarkAsTarget() { desc_.set_is_target(true); } void MarkAsTarget() { desc_.set_is_target(true); }

@ -47,7 +47,7 @@ static VariableNameMap ConvertOpDescVarsToVarNameMap(
std::unique_ptr<OperatorBase> OpRegistry::CreateOp( std::unique_ptr<OperatorBase> OpRegistry::CreateOp(
const proto::OpDesc& op_desc) { const proto::OpDesc& op_desc) {
VLOG(1) << "CreateOp directly from OpDesc is deprecated. It should only be" VLOG(1) << "CreateOp directly from OpDesc is deprecated. It should only be"
"used in unit tests. Use CreateOp(const OpDescBind& op_desc) " "used in unit tests. Use CreateOp(const OpDesc& op_desc) "
"instead."; "instead.";
VariableNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs()); VariableNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs());
VariableNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs()); VariableNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs());
@ -59,7 +59,7 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp(
return CreateOp(op_desc.type(), inputs, outputs, attrs); return CreateOp(op_desc.type(), inputs, outputs, attrs);
} }
std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDescBind& op_desc) { std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc) {
return CreateOp(op_desc.Type(), op_desc.Inputs(), op_desc.Outputs(), return CreateOp(op_desc.Type(), op_desc.Inputs(), op_desc.Outputs(),
op_desc.GetAttrMap()); op_desc.GetAttrMap());
} }

@ -79,7 +79,7 @@ class OpRegistry {
static std::unique_ptr<OperatorBase> CreateOp(const proto::OpDesc& op_desc); static std::unique_ptr<OperatorBase> CreateOp(const proto::OpDesc& op_desc);
static std::unique_ptr<OperatorBase> CreateOp(const OpDescBind& op_desc); static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc);
}; };
template <typename PlaceType, bool at_end, size_t I, typename... KernelType> template <typename PlaceType, bool at_end, size_t I, typename... KernelType>

@ -18,49 +18,49 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
BlockDescBind *ProgramDescBind::AppendBlock(const BlockDescBind &parent) { BlockDesc *ProgramDesc::AppendBlock(const BlockDesc &parent) {
auto *b = desc_.add_blocks(); auto *b = desc_.add_blocks();
b->set_parent_idx(parent.ID()); b->set_parent_idx(parent.ID());
b->set_idx(desc_.blocks_size() - 1); b->set_idx(desc_.blocks_size() - 1);
blocks_.emplace_back(new BlockDescBind(this, b)); blocks_.emplace_back(new BlockDesc(this, b));
return blocks_.back().get(); return blocks_.back().get();
} }
proto::ProgramDesc *ProgramDescBind::Proto() { proto::ProgramDesc *ProgramDesc::Proto() {
for (auto &block : blocks_) { for (auto &block : blocks_) {
block->Flush(); block->Flush();
} }
return &desc_; return &desc_;
} }
ProgramDescBind::ProgramDescBind() { ProgramDesc::ProgramDesc() {
auto *block = desc_.mutable_blocks()->Add(); auto *block = desc_.mutable_blocks()->Add();
block->set_idx(kRootBlockIndex); block->set_idx(kRootBlockIndex);
block->set_parent_idx(kNoneBlockIndex); block->set_parent_idx(kNoneBlockIndex);
blocks_.emplace_back(new BlockDescBind(this, block)); blocks_.emplace_back(new BlockDesc(this, block));
} }
ProgramDescBind::ProgramDescBind(const ProgramDescBind &o) { ProgramDesc::ProgramDesc(const ProgramDesc &o) {
desc_ = o.desc_; desc_ = o.desc_;
for (int i = 0; i < desc_.blocks_size(); ++i) { for (int i = 0; i < desc_.blocks_size(); ++i) {
auto *block = desc_.mutable_blocks(i); auto *block = desc_.mutable_blocks(i);
blocks_.emplace_back(new BlockDescBind(*o.blocks_[i], block, this)); blocks_.emplace_back(new BlockDesc(*o.blocks_[i], block, this));
} }
} }
ProgramDescBind::ProgramDescBind(const proto::ProgramDesc &desc) { ProgramDesc::ProgramDesc(const proto::ProgramDesc &desc) {
desc_ = desc; desc_ = desc;
for (auto &block_desc : *desc_.mutable_blocks()) { for (auto &block_desc : *desc_.mutable_blocks()) {
blocks_.emplace_back(new BlockDescBind(this, &block_desc)); blocks_.emplace_back(new BlockDesc(this, &block_desc));
} }
} }
ProgramDescBind::ProgramDescBind(const std::string &binary_str) { ProgramDesc::ProgramDesc(const std::string &binary_str) {
PADDLE_ENFORCE(desc_.ParseFromString(binary_str), PADDLE_ENFORCE(desc_.ParseFromString(binary_str),
"Fail to parse program_desc from binary string."); "Fail to parse program_desc from binary string.");
for (auto &block_desc : *desc_.mutable_blocks()) { for (auto &block_desc : *desc_.mutable_blocks()) {
blocks_.emplace_back(new BlockDescBind(this, &block_desc)); blocks_.emplace_back(new BlockDesc(this, &block_desc));
} }
} }

@ -23,23 +23,23 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class BlockDescBind; class BlockDesc;
class ProgramDescBind { class ProgramDesc {
public: public:
ProgramDescBind(); ProgramDesc();
explicit ProgramDescBind(const proto::ProgramDesc &desc); explicit ProgramDesc(const proto::ProgramDesc &desc);
ProgramDescBind(const ProgramDescBind &o); ProgramDesc(const ProgramDesc &o);
explicit ProgramDescBind(const std::string &binary_str); explicit ProgramDesc(const std::string &binary_str);
BlockDescBind *AppendBlock(const BlockDescBind &parent); BlockDesc *AppendBlock(const BlockDesc &parent);
BlockDescBind *MutableBlock(size_t idx) { return blocks_[idx].get(); } BlockDesc *MutableBlock(size_t idx) { return blocks_[idx].get(); }
const BlockDescBind &Block(size_t idx) const { return *blocks_[idx]; } const BlockDesc &Block(size_t idx) const { return *blocks_[idx]; }
size_t Size() const { return blocks_.size(); } size_t Size() const { return blocks_.size(); }
@ -48,7 +48,7 @@ class ProgramDescBind {
private: private:
proto::ProgramDesc desc_; proto::ProgramDesc desc_;
std::vector<std::unique_ptr<BlockDescBind>> blocks_; std::vector<std::unique_ptr<BlockDesc>> blocks_;
}; };
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle

@ -19,7 +19,7 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
TEST(ProgramDesc, copy_ctor) { TEST(ProgramDesc, copy_ctor) {
ProgramDescBind program; ProgramDesc program;
auto* global_block = program.MutableBlock(0); auto* global_block = program.MutableBlock(0);
auto* x = global_block->Var("X"); auto* x = global_block->Var("X");
x->SetType(proto::VarDesc_VarType_LOD_TENSOR); x->SetType(proto::VarDesc_VarType_LOD_TENSOR);
@ -42,12 +42,12 @@ TEST(ProgramDesc, copy_ctor) {
out->SetType(proto::VarDesc_VarType_LOD_TENSOR); out->SetType(proto::VarDesc_VarType_LOD_TENSOR);
op->SetOutput("Y", {out->Name()}); op->SetOutput("Y", {out->Name()});
ProgramDescBind program_copy(program); ProgramDesc program_copy(program);
auto* global_block_copy = program_copy.MutableBlock(0); auto* global_block_copy = program_copy.MutableBlock(0);
ASSERT_NE(global_block, global_block_copy); ASSERT_NE(global_block, global_block_copy);
auto assert_same_var = [&](const std::string& name, VarDescBind* var_before) { auto assert_same_var = [&](const std::string& name, VarDesc* var_before) {
ASSERT_TRUE(global_block_copy->HasVar(name)); ASSERT_TRUE(global_block_copy->HasVar(name));
auto* copy = global_block_copy->Var(name); auto* copy = global_block_copy->Var(name);
ASSERT_NE(copy, var_before); ASSERT_NE(copy, var_before);
@ -81,7 +81,7 @@ TEST(ProgramDesc, copy_ctor) {
} }
TEST(ProgramDescBind, serialize_and_deserialize) { TEST(ProgramDescBind, serialize_and_deserialize) {
ProgramDescBind program_origin; ProgramDesc program_origin;
auto* global_block = program_origin.MutableBlock(0); auto* global_block = program_origin.MutableBlock(0);
auto* x = global_block->Var("X"); auto* x = global_block->Var("X");
x->SetType(proto::VarDesc_VarType_LOD_TENSOR); x->SetType(proto::VarDesc_VarType_LOD_TENSOR);
@ -107,11 +107,11 @@ TEST(ProgramDescBind, serialize_and_deserialize) {
std::string binary_str; std::string binary_str;
program_origin.Proto()->SerializeToString(&binary_str); program_origin.Proto()->SerializeToString(&binary_str);
ProgramDescBind program_restored(binary_str); ProgramDesc program_restored(binary_str);
auto* global_block_restored = program_restored.MutableBlock(0); auto* global_block_restored = program_restored.MutableBlock(0);
ASSERT_NE(global_block, global_block_restored); ASSERT_NE(global_block, global_block_restored);
auto assert_same_var = [&](const std::string& name, VarDescBind* var_before) { auto assert_same_var = [&](const std::string& name, VarDesc* var_before) {
ASSERT_TRUE(global_block_restored->HasVar(name)); ASSERT_TRUE(global_block_restored->HasVar(name));
auto* restored = global_block_restored->Var(name); auto* restored = global_block_restored->Var(name);
ASSERT_NE(restored, var_before); ASSERT_NE(restored, var_before);

@ -29,7 +29,7 @@ namespace ops = paddle::operators;
void AddOp(const std::string &type, const f::VariableNameMap &inputs, void AddOp(const std::string &type, const f::VariableNameMap &inputs,
const f::VariableNameMap &outputs, f::AttributeMap attrs, const f::VariableNameMap &outputs, f::AttributeMap attrs,
paddle::framework::BlockDescBind *block) { paddle::framework::BlockDesc *block) {
// insert output // insert output
for (auto kv : outputs) { for (auto kv : outputs) {
for (auto v : kv.second) { for (auto v : kv.second) {
@ -51,8 +51,8 @@ void AddOp(const std::string &type, const f::VariableNameMap &inputs,
} }
TEST(Prune, one_operator) { TEST(Prune, one_operator) {
f::ProgramDescBind program; f::ProgramDesc program;
f::BlockDescBind *block = program.MutableBlock(0); f::BlockDesc *block = program.MutableBlock(0);
AddOp("one_one", {{"input", {"a"}}}, {{"output", {"b"}}}, f::AttributeMap{}, AddOp("one_one", {{"input", {"a"}}}, {{"output", {"b"}}}, f::AttributeMap{},
block); block);
@ -69,8 +69,8 @@ TEST(Prune, one_operator) {
} }
TEST(Prune, forward) { TEST(Prune, forward) {
f::ProgramDescBind program; f::ProgramDesc program;
f::BlockDescBind *block = program.MutableBlock(0); f::BlockDesc *block = program.MutableBlock(0);
AddOp("one_one", {{"input", {"a"}}}, {{"output", {"b"}}}, f::AttributeMap{}, AddOp("one_one", {{"input", {"a"}}}, {{"output", {"b"}}}, f::AttributeMap{},
block); block);
@ -92,8 +92,8 @@ TEST(Prune, forward) {
} }
TEST(Prune, multi_input_op) { TEST(Prune, multi_input_op) {
f::ProgramDescBind program; f::ProgramDesc program;
f::BlockDescBind *block = program.MutableBlock(0); f::BlockDesc *block = program.MutableBlock(0);
AddOp("one_one", {{"input", {"a0"}}}, {{"output", {"b0"}}}, f::AttributeMap{}, AddOp("one_one", {{"input", {"a0"}}}, {{"output", {"b0"}}}, f::AttributeMap{},
block); block);
@ -113,8 +113,8 @@ TEST(Prune, multi_input_op) {
} }
TEST(Prune, multi_output_op) { TEST(Prune, multi_output_op) {
f::ProgramDescBind program; f::ProgramDesc program;
f::BlockDescBind *block = program.MutableBlock(0); f::BlockDesc *block = program.MutableBlock(0);
AddOp("one_two", {{"input", {"a"}}}, {{"output", {"b", "c"}}}, AddOp("one_two", {{"input", {"a"}}}, {{"output", {"b", "c"}}},
f::AttributeMap{}, block); f::AttributeMap{}, block);
@ -132,8 +132,8 @@ TEST(Prune, multi_output_op) {
} }
TEST(Prune, multi_target) { TEST(Prune, multi_target) {
f::ProgramDescBind program; f::ProgramDesc program;
f::BlockDescBind *block = program.MutableBlock(0); f::BlockDesc *block = program.MutableBlock(0);
AddOp("one_two", {{"input", {"a"}}}, {{"output", {"b", "c"}}}, AddOp("one_two", {{"input", {"a"}}}, {{"output", {"b", "c"}}},
f::AttributeMap{}, block); f::AttributeMap{}, block);

@ -25,11 +25,9 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class OperatorBase; class OperatorBase;
class OpDescBind; class OpDesc;
class BlockDescBind;
class BlockDesc;
class InferShapeContext; class InferShapeContext;
class BlockDescBind; class BlockDesc;
using VariableNameMap = std::map<std::string, std::vector<std::string>>; using VariableNameMap = std::map<std::string, std::vector<std::string>>;
@ -37,7 +35,7 @@ using VariableNameMap = std::map<std::string, std::vector<std::string>>;
using Attribute = using Attribute =
boost::variant<boost::blank, int, float, std::string, std::vector<int>, boost::variant<boost::blank, int, float, std::string, std::vector<int>,
std::vector<float>, std::vector<std::string>, bool, std::vector<float>, std::vector<std::string>, bool,
std::vector<bool>, BlockDescBind*>; std::vector<bool>, BlockDesc*>;
using AttributeMap = std::unordered_map<std::string, Attribute>; using AttributeMap = std::unordered_map<std::string, Attribute>;
@ -45,13 +43,13 @@ using OpCreator = std::function<OperatorBase*(
const std::string& /*type*/, const VariableNameMap& /*inputs*/, const std::string& /*type*/, const VariableNameMap& /*inputs*/,
const VariableNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>; const VariableNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>;
using GradOpMakerFN = std::function<std::vector<std::unique_ptr<OpDescBind>>( using GradOpMakerFN = std::function<std::vector<std::unique_ptr<OpDesc>>(
const OpDescBind&, const std::unordered_set<std::string>& /*no_grad_set*/, const OpDesc&, const std::unordered_set<std::string>& /*no_grad_set*/,
std::unordered_map<std::string, std::string>* /*grad_to_var*/, std::unordered_map<std::string, std::string>* /*grad_to_var*/,
const std::vector<BlockDescBind*>& grad_block)>; const std::vector<BlockDesc*>& grad_block)>;
using InferVarTypeFN = std::function<void(const OpDescBind& /*op_desc*/, using InferVarTypeFN =
BlockDescBind* /*block*/)>; std::function<void(const OpDesc& /*op_desc*/, BlockDesc* /*block*/)>;
using InferShapeFN = std::function<void(InferShapeContext*)>; using InferShapeFN = std::function<void(InferShapeContext*)>;

@ -18,29 +18,27 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
proto::VarDesc::VarType VarDescBind::GetType() const { return desc_.type(); } proto::VarDesc::VarType VarDesc::GetType() const { return desc_.type(); }
void VarDescBind::SetType(proto::VarDesc::VarType type) { void VarDesc::SetType(proto::VarDesc::VarType type) { desc_.set_type(type); }
desc_.set_type(type);
}
void VarDescBind::SetShape(const std::vector<int64_t> &dims) { void VarDesc::SetShape(const std::vector<int64_t> &dims) {
VectorToRepeated(dims, mutable_tensor_desc()->mutable_dims()); VectorToRepeated(dims, mutable_tensor_desc()->mutable_dims());
} }
void VarDescBind::SetDataType(proto::DataType data_type) { void VarDesc::SetDataType(proto::DataType data_type) {
mutable_tensor_desc()->set_data_type(data_type); mutable_tensor_desc()->set_data_type(data_type);
} }
std::vector<int64_t> VarDescBind::Shape() const { std::vector<int64_t> VarDesc::Shape() const {
return RepeatedToVector(tensor_desc().dims()); return RepeatedToVector(tensor_desc().dims());
} }
proto::DataType VarDescBind::GetDataType() const { proto::DataType VarDesc::GetDataType() const {
return tensor_desc().data_type(); return tensor_desc().data_type();
} }
void VarDescBind::SetLoDLevel(int32_t lod_level) { void VarDesc::SetLoDLevel(int32_t lod_level) {
switch (desc_.type()) { switch (desc_.type()) {
case proto::VarDesc::LOD_TENSOR: case proto::VarDesc::LOD_TENSOR:
desc_.mutable_lod_tensor()->set_lod_level(lod_level); desc_.mutable_lod_tensor()->set_lod_level(lod_level);
@ -54,7 +52,7 @@ void VarDescBind::SetLoDLevel(int32_t lod_level) {
} }
} }
int32_t VarDescBind::GetLodLevel() const { int32_t VarDesc::GetLodLevel() const {
switch (desc_.type()) { switch (desc_.type()) {
case proto::VarDesc::LOD_TENSOR: case proto::VarDesc::LOD_TENSOR:
return desc_.lod_tensor().lod_level(); return desc_.lod_tensor().lod_level();
@ -66,7 +64,7 @@ int32_t VarDescBind::GetLodLevel() const {
} }
} }
const proto::TensorDesc &VarDescBind::tensor_desc() const { const proto::TensorDesc &VarDesc::tensor_desc() const {
PADDLE_ENFORCE(desc_.has_type(), "invoke TensorDesc must after set type"); PADDLE_ENFORCE(desc_.has_type(), "invoke TensorDesc must after set type");
switch (desc_.type()) { switch (desc_.type()) {
case proto::VarDesc::SELECTED_ROWS: case proto::VarDesc::SELECTED_ROWS:
@ -80,7 +78,7 @@ const proto::TensorDesc &VarDescBind::tensor_desc() const {
} }
} }
proto::TensorDesc *VarDescBind::mutable_tensor_desc() { proto::TensorDesc *VarDesc::mutable_tensor_desc() {
PADDLE_ENFORCE(desc_.has_type(), PADDLE_ENFORCE(desc_.has_type(),
"invoke MutableTensorDesc must after set type"); "invoke MutableTensorDesc must after set type");
switch (desc_.type()) { switch (desc_.type()) {

@ -53,14 +53,14 @@ inline void VectorToRepeated(const std::vector<bool> &vec,
} }
} }
class VarDescBind { class VarDesc {
public: public:
explicit VarDescBind(const std::string &name) { explicit VarDesc(const std::string &name) {
desc_.set_name(name); desc_.set_name(name);
desc_.set_type(proto::VarDesc::LOD_TENSOR); desc_.set_type(proto::VarDesc::LOD_TENSOR);
} }
explicit VarDescBind(const proto::VarDesc &desc) : desc_(desc) {} explicit VarDesc(const proto::VarDesc &desc) : desc_(desc) {}
proto::VarDesc *Proto() { return &desc_; } proto::VarDesc *Proto() { return &desc_; }

@ -21,8 +21,7 @@ namespace framework {
class VarTypeInference { class VarTypeInference {
public: public:
virtual ~VarTypeInference() {} virtual ~VarTypeInference() {}
virtual void operator()(const OpDescBind& op_desc, virtual void operator()(const OpDesc& op_desc, BlockDesc* block) const = 0;
BlockDescBind* block) const = 0;
}; };
} // namespace framework } // namespace framework

@ -33,8 +33,7 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
class SumOpVarTypeInference : public VarTypeInference { class SumOpVarTypeInference : public VarTypeInference {
public: public:
void operator()(const OpDescBind &op_desc, void operator()(const OpDesc &op_desc, BlockDesc *block) const override {
BlockDescBind *block) const override {
auto &inputs = op_desc.Input("X"); auto &inputs = op_desc.Input("X");
auto default_var_type = proto::VarDesc::SELECTED_ROWS; auto default_var_type = proto::VarDesc::SELECTED_ROWS;
@ -62,7 +61,7 @@ namespace paddle {
namespace framework { namespace framework {
TEST(InferVarType, sum_op) { TEST(InferVarType, sum_op) {
ProgramDescBind prog; ProgramDesc prog;
auto *op = prog.MutableBlock(0)->AppendOp(); auto *op = prog.MutableBlock(0)->AppendOp();
op->SetType("sum"); op->SetType("sum");
op->SetInput("X", {"test_a", "test_b", "test_c"}); op->SetInput("X", {"test_a", "test_b", "test_c"});
@ -85,7 +84,7 @@ TEST(InferVarType, sum_op) {
} }
TEST(InferVarType, sum_op_without_infer_var_type) { TEST(InferVarType, sum_op_without_infer_var_type) {
ProgramDescBind prog; ProgramDesc prog;
auto *op = prog.MutableBlock(0)->AppendOp(); auto *op = prog.MutableBlock(0)->AppendOp();
op->SetType("sum_without_infer_var_type"); op->SetType("sum_without_infer_var_type");
op->SetInput("X", {"test2_a", "test2_b", "test2_c"}); op->SetInput("X", {"test2_a", "test2_b", "test2_c"});

@ -149,14 +149,14 @@ class ArrayToLoDTensorGradMaker : public framework::SingleGradOpDescMaker {
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected: protected:
std::unique_ptr<framework::OpDescBind> Apply() const override { std::unique_ptr<framework::OpDesc> Apply() const override {
auto *grad_op = new framework::OpDescBind(); auto *grad_op = new framework::OpDesc();
grad_op->SetType("lod_tensor_to_array"); grad_op->SetType("lod_tensor_to_array");
grad_op->SetInput("X", OutputGrad("Out")); grad_op->SetInput("X", OutputGrad("Out"));
grad_op->SetInput("RankTable", Input("RankTable")); grad_op->SetInput("RankTable", Input("RankTable"));
grad_op->SetOutput("Out", InputGrad("X")); grad_op->SetOutput("Out", InputGrad("X"));
grad_op->SetAttrMap(Attrs()); grad_op->SetAttrMap(Attrs());
return std::unique_ptr<framework::OpDescBind>(grad_op); return std::unique_ptr<framework::OpDesc>(grad_op);
} }
}; };

@ -121,12 +121,12 @@ class AssignGradMaker : public framework::SingleGradOpDescMaker {
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker; using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected: protected:
std::unique_ptr<framework::OpDescBind> Apply() const override { std::unique_ptr<framework::OpDesc> Apply() const override {
auto *op = new framework::OpDescBind(); auto *op = new framework::OpDesc();
op->SetType("assign"); op->SetType("assign");
op->SetInput("X", OutputGrad("Out")); op->SetInput("X", OutputGrad("Out"));
op->SetOutput("Out", InputGrad("X")); op->SetOutput("Out", InputGrad("X"));
return std::unique_ptr<framework::OpDescBind>(op); return std::unique_ptr<framework::OpDesc>(op);
} }
}; };

@ -119,8 +119,8 @@ class BeamSearchDecodeInferShape : public framework::InferShapeBase {
class BeamSearchDecodeInferVarType : public framework::VarTypeInference { class BeamSearchDecodeInferVarType : public framework::VarTypeInference {
public: public:
void operator()(const framework::OpDescBind& op_desc, void operator()(const framework::OpDesc& op_desc,
framework::BlockDescBind* block) const override { framework::BlockDesc* block) const override {
for (auto& o : op_desc.Output("SentenceIds")) { for (auto& o : op_desc.Output("SentenceIds")) {
block->Var(o)->SetType(framework::proto::VarDesc::LOD_TENSOR); block->Var(o)->SetType(framework::proto::VarDesc::LOD_TENSOR);
} }

Some files were not shown because too many files have changed in this diff Show More

Loading…
Cancel
Save