|
|
@ -21,49 +21,81 @@ namespace paddle {
|
|
|
|
namespace imperative {
|
|
|
|
namespace imperative {
|
|
|
|
namespace jit {
|
|
|
|
namespace jit {
|
|
|
|
|
|
|
|
|
|
|
|
void ProgramDescTracer::SetNamePrefix(const std::string &name_prefix) {
|
|
|
|
// A helper class to generate unique name for each non-persistable var
|
|
|
|
name_prefix_ = name_prefix;
|
|
|
|
class UniqueBlockVarGenerator {
|
|
|
|
}
|
|
|
|
public:
|
|
|
|
|
|
|
|
UniqueBlockVarGenerator(const VarDescMetaMap &all_vars,
|
|
|
|
|
|
|
|
framework::BlockDesc *block);
|
|
|
|
|
|
|
|
|
|
|
|
void ProgramDescTracer::SetFeedVars(
|
|
|
|
std::string NameOf(const std::weak_ptr<VarBase> &var,
|
|
|
|
const std::vector<std::shared_ptr<VarBase>> &feed_vars,
|
|
|
|
const std::string &prefix);
|
|
|
|
std::vector<std::string> feed_names) {
|
|
|
|
|
|
|
|
feed_vars_.clear();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (feed_names.empty()) {
|
|
|
|
private:
|
|
|
|
feed_names.reserve(feed_vars.size());
|
|
|
|
void InsertNewVarInBlock(const std::weak_ptr<VarBase> &var,
|
|
|
|
for (auto &var : feed_vars) {
|
|
|
|
const framework::VarDesc &ref_desc,
|
|
|
|
feed_names.emplace_back(var->Name());
|
|
|
|
const std::string &name);
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(feed_names.size(), feed_vars.size(),
|
|
|
|
private:
|
|
|
|
"The feeded variable names number must be equal to the "
|
|
|
|
const VarDescMetaMap &all_vars_;
|
|
|
|
"feeded variable number");
|
|
|
|
framework::BlockDesc *block_;
|
|
|
|
|
|
|
|
std::unordered_map<std::string, size_t> counter_;
|
|
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < feed_names.size(); ++i) {
|
|
|
|
std::map<std::weak_ptr<VarBase>, std::string,
|
|
|
|
feed_vars_[feed_vars[i]] = feed_names[i];
|
|
|
|
std::owner_less<std::weak_ptr<VarBase>>>
|
|
|
|
|
|
|
|
var_to_name_;
|
|
|
|
|
|
|
|
std::unordered_set<std::string> existing_names_;
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
UniqueBlockVarGenerator::UniqueBlockVarGenerator(const VarDescMetaMap &all_vars,
|
|
|
|
|
|
|
|
framework::BlockDesc *block)
|
|
|
|
|
|
|
|
: all_vars_(all_vars), block_(block) {
|
|
|
|
|
|
|
|
for (auto &var_pair : all_vars_) {
|
|
|
|
|
|
|
|
auto *var_desc = var_pair.second.get();
|
|
|
|
|
|
|
|
if (var_desc->Persistable()) {
|
|
|
|
|
|
|
|
InsertNewVarInBlock(var_pair.first, *var_desc, var_desc->Name());
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void ProgramDescTracer::SetFetchVars(
|
|
|
|
std::string UniqueBlockVarGenerator::NameOf(const std::weak_ptr<VarBase> &var,
|
|
|
|
const std::vector<std::shared_ptr<VarBase>> &fetch_vars,
|
|
|
|
const std::string &prefix) {
|
|
|
|
std::vector<std::string> fetch_names) {
|
|
|
|
auto all_vars_iter = all_vars_.find(var);
|
|
|
|
fetch_vars_.clear();
|
|
|
|
PADDLE_ENFORCE_EQ(all_vars_iter != all_vars_.end(), true,
|
|
|
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
if (fetch_names.empty()) {
|
|
|
|
"Variable is not found in UniqueBlockVarGenerator"));
|
|
|
|
fetch_names.reserve(fetch_vars.size());
|
|
|
|
|
|
|
|
for (auto &var : fetch_vars) {
|
|
|
|
auto iter = var_to_name_.find(var);
|
|
|
|
fetch_names.emplace_back(var->Name());
|
|
|
|
if (iter != var_to_name_.end()) {
|
|
|
|
}
|
|
|
|
VLOG(5) << "Return existing var name " << iter->second;
|
|
|
|
|
|
|
|
return iter->second;
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
auto generate_unique_name = [this, &prefix] {
|
|
|
|
|
|
|
|
auto &cnt = counter_[prefix];
|
|
|
|
|
|
|
|
do {
|
|
|
|
|
|
|
|
auto name = prefix + std::to_string(cnt++);
|
|
|
|
|
|
|
|
if (existing_names_.count(name) == 0) {
|
|
|
|
|
|
|
|
return name;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
} while (cnt > 0);
|
|
|
|
|
|
|
|
PADDLE_THROW(
|
|
|
|
|
|
|
|
platform::errors::OutOfRange("Too many vars in the program"));
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto unique_name = generate_unique_name();
|
|
|
|
|
|
|
|
VLOG(5) << "Generate new var name " << unique_name;
|
|
|
|
|
|
|
|
InsertNewVarInBlock(var, *(all_vars_iter->second), unique_name);
|
|
|
|
|
|
|
|
return unique_name;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(fetch_names.size(), fetch_vars.size(),
|
|
|
|
void UniqueBlockVarGenerator::InsertNewVarInBlock(
|
|
|
|
"The fetched variable names number must be equal to the "
|
|
|
|
const std::weak_ptr<VarBase> &var, const framework::VarDesc &var_desc,
|
|
|
|
"fetched variable number");
|
|
|
|
const std::string &name) {
|
|
|
|
for (size_t i = 0; i < fetch_names.size(); ++i) {
|
|
|
|
var_to_name_[var] = name;
|
|
|
|
fetch_vars_[fetch_vars[i]] = fetch_names[i];
|
|
|
|
existing_names_.insert(name);
|
|
|
|
}
|
|
|
|
auto *new_var_desc = block_->Var(name);
|
|
|
|
|
|
|
|
*new_var_desc = var_desc;
|
|
|
|
|
|
|
|
new_var_desc->SetName(name);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void ProgramDescTracer::InsertOp(const std::string &type,
|
|
|
|
void ProgramDescTracer::InsertOp(const std::string &type,
|
|
|
@ -85,70 +117,24 @@ void ProgramDescTracer::InsertOp(const std::string &type,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<framework::ProgramDesc> ProgramDescTracer::CreateProgramDesc()
|
|
|
|
TracedProgramTuple ProgramDescTracer::CreateProgramDesc(
|
|
|
|
const {
|
|
|
|
const std::vector<std::shared_ptr<VarBase>> &feed_vars,
|
|
|
|
|
|
|
|
const std::string &feed_prefix,
|
|
|
|
|
|
|
|
const std::vector<std::shared_ptr<VarBase>> &fetch_vars,
|
|
|
|
|
|
|
|
const std::string &fetch_prefix, const std::string &tmp_prefix) const {
|
|
|
|
std::unique_ptr<framework::ProgramDesc> prog(new framework::ProgramDesc());
|
|
|
|
std::unique_ptr<framework::ProgramDesc> prog(new framework::ProgramDesc());
|
|
|
|
auto *block = prog->MutableBlock(0);
|
|
|
|
auto *block = prog->MutableBlock(0);
|
|
|
|
|
|
|
|
|
|
|
|
size_t var_num = vars_.size();
|
|
|
|
UniqueBlockVarGenerator generator(vars_, block);
|
|
|
|
std::vector<framework::VarDesc *> var_descs(var_num, nullptr);
|
|
|
|
|
|
|
|
std::unordered_map<framework::VarDesc *, std::weak_ptr<VarBase>>
|
|
|
|
|
|
|
|
var_desc_to_var_base;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (auto &pair : vars_) {
|
|
|
|
|
|
|
|
size_t var_id = pair.second.first;
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_LT(var_id, var_num);
|
|
|
|
|
|
|
|
var_descs[var_id] = pair.second.second.get();
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(var_descs[var_id]);
|
|
|
|
|
|
|
|
var_desc_to_var_base[var_descs[var_id]] = pair.first;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::unordered_set<std::string> existing_var_names;
|
|
|
|
|
|
|
|
for (auto *var_desc : var_descs) {
|
|
|
|
|
|
|
|
if (var_desc->Persistable()) {
|
|
|
|
|
|
|
|
existing_var_names.insert(var_desc->Name());
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (auto &pair : feed_vars_) {
|
|
|
|
|
|
|
|
existing_var_names.insert(pair.second);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for (auto &pair : fetch_vars_) {
|
|
|
|
std::vector<std::string> feed_var_names;
|
|
|
|
existing_var_names.insert(pair.second);
|
|
|
|
for (auto &feed_var : feed_vars) {
|
|
|
|
|
|
|
|
feed_var_names.emplace_back(generator.NameOf(feed_var, feed_prefix));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
size_t counter = 0;
|
|
|
|
std::vector<std::string> fetch_var_names;
|
|
|
|
auto generate_unique_name = [&]() -> std::string {
|
|
|
|
for (auto &fetch_var : fetch_vars) {
|
|
|
|
do {
|
|
|
|
fetch_var_names.emplace_back(generator.NameOf(fetch_var, fetch_prefix));
|
|
|
|
auto name = name_prefix_ + std::to_string(counter++);
|
|
|
|
|
|
|
|
if (existing_var_names.count(name) == 0) {
|
|
|
|
|
|
|
|
existing_var_names.insert(name);
|
|
|
|
|
|
|
|
return name;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
} while (counter > 0);
|
|
|
|
|
|
|
|
PADDLE_THROW("Too many vars in the program");
|
|
|
|
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::map<std::weak_ptr<VarBase>, std::string,
|
|
|
|
|
|
|
|
std::owner_less<std::weak_ptr<VarBase>>>
|
|
|
|
|
|
|
|
var_to_name;
|
|
|
|
|
|
|
|
for (auto *var_desc : var_descs) {
|
|
|
|
|
|
|
|
auto var_name = var_desc->Name();
|
|
|
|
|
|
|
|
PADDLE_ENFORCE_EQ(var_desc_to_var_base.count(var_desc), 1);
|
|
|
|
|
|
|
|
std::weak_ptr<VarBase> var_base = var_desc_to_var_base.at(var_desc);
|
|
|
|
|
|
|
|
if (feed_vars_.count(var_base) > 0) {
|
|
|
|
|
|
|
|
var_name = feed_vars_.at(var_base);
|
|
|
|
|
|
|
|
} else if (fetch_vars_.count(var_base) > 0) {
|
|
|
|
|
|
|
|
var_name = fetch_vars_.at(var_base);
|
|
|
|
|
|
|
|
} else if (!var_desc->Persistable()) {
|
|
|
|
|
|
|
|
var_name = generate_unique_name();
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
auto *new_var_desc = block->Var(var_name);
|
|
|
|
|
|
|
|
*new_var_desc = *var_desc;
|
|
|
|
|
|
|
|
new_var_desc->SetName(std::move(var_name));
|
|
|
|
|
|
|
|
var_to_name[var_base] = new_var_desc->Name();
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
for (auto &op : ops_) {
|
|
|
|
for (auto &op : ops_) {
|
|
|
@ -160,10 +146,7 @@ std::unique_ptr<framework::ProgramDesc> ProgramDescTracer::CreateProgramDesc()
|
|
|
|
std::vector<std::string> names;
|
|
|
|
std::vector<std::string> names;
|
|
|
|
names.reserve(pair.second.size());
|
|
|
|
names.reserve(pair.second.size());
|
|
|
|
for (auto &var : pair.second) {
|
|
|
|
for (auto &var : pair.second) {
|
|
|
|
auto iter = var_to_name.find(var);
|
|
|
|
names.emplace_back(generator.NameOf(var, tmp_prefix));
|
|
|
|
PADDLE_ENFORCE_EQ(iter != var_to_name.end(), true,
|
|
|
|
|
|
|
|
"Cannot find input variable");
|
|
|
|
|
|
|
|
names.emplace_back(iter->second);
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
op_desc->SetInput(pair.first, std::move(names));
|
|
|
|
op_desc->SetInput(pair.first, std::move(names));
|
|
|
@ -173,10 +156,7 @@ std::unique_ptr<framework::ProgramDesc> ProgramDescTracer::CreateProgramDesc()
|
|
|
|
std::vector<std::string> names;
|
|
|
|
std::vector<std::string> names;
|
|
|
|
names.reserve(pair.second.size());
|
|
|
|
names.reserve(pair.second.size());
|
|
|
|
for (auto &var : pair.second) {
|
|
|
|
for (auto &var : pair.second) {
|
|
|
|
auto iter = var_to_name.find(var);
|
|
|
|
names.emplace_back(generator.NameOf(var, tmp_prefix));
|
|
|
|
PADDLE_ENFORCE_EQ(iter != var_to_name.end(), true,
|
|
|
|
|
|
|
|
"Cannot find output variable");
|
|
|
|
|
|
|
|
names.emplace_back(iter->second);
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
op_desc->SetOutput(pair.first, std::move(names));
|
|
|
|
op_desc->SetOutput(pair.first, std::move(names));
|
|
|
@ -184,7 +164,8 @@ std::unique_ptr<framework::ProgramDesc> ProgramDescTracer::CreateProgramDesc()
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
prog->Flush();
|
|
|
|
prog->Flush();
|
|
|
|
return prog;
|
|
|
|
return std::make_tuple(std::move(prog), std::move(feed_var_names),
|
|
|
|
|
|
|
|
std::move(fetch_var_names));
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void ProgramDescTracer::InsertVarIfNotExist(
|
|
|
|
void ProgramDescTracer::InsertVarIfNotExist(
|
|
|
@ -192,10 +173,8 @@ void ProgramDescTracer::InsertVarIfNotExist(
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(new_var);
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(new_var);
|
|
|
|
if (vars_.count(new_var) != 0) return;
|
|
|
|
if (vars_.count(new_var) != 0) return;
|
|
|
|
|
|
|
|
|
|
|
|
size_t var_id = vars_.size();
|
|
|
|
|
|
|
|
auto new_var_desc = new framework::VarDesc("");
|
|
|
|
auto new_var_desc = new framework::VarDesc("");
|
|
|
|
vars_[new_var] =
|
|
|
|
vars_[new_var].reset(new_var_desc);
|
|
|
|
std::make_pair(var_id, std::unique_ptr<framework::VarDesc>(new_var_desc));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (new_var->Persistable()) {
|
|
|
|
if (new_var->Persistable()) {
|
|
|
|
new_var_desc->SetName(new_var->Name());
|
|
|
|
new_var_desc->SetName(new_var->Name());
|
|
|
@ -225,9 +204,6 @@ void ProgramDescTracer::InsertVarIfNotExist(
|
|
|
|
void ProgramDescTracer::Reset() {
|
|
|
|
void ProgramDescTracer::Reset() {
|
|
|
|
ops_.clear();
|
|
|
|
ops_.clear();
|
|
|
|
vars_.clear();
|
|
|
|
vars_.clear();
|
|
|
|
feed_vars_.clear();
|
|
|
|
|
|
|
|
fetch_vars_.clear();
|
|
|
|
|
|
|
|
name_prefix_.clear();
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
} // namespace jit
|
|
|
|
} // namespace jit
|
|
|
|