|
|
|
@ -32,23 +32,11 @@ void ReadBinaryFile(const std::string& filename, std::string& contents) {
|
|
|
|
|
inputfs.close();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool IsParameter(const framework::VarDesc* var,
|
|
|
|
|
const framework::ProgramDesc& main_program) {
|
|
|
|
|
if (var->Persistable()) {
|
|
|
|
|
// There are many unreachable variables in the program
|
|
|
|
|
for (size_t i = 0; i < main_program.Size(); ++i) {
|
|
|
|
|
const framework::BlockDesc& block = main_program.Block(i);
|
|
|
|
|
for (auto* op : block.AllOps()) {
|
|
|
|
|
if (op->Type() == framework::kFeedOpType) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
for (auto input_argument_name : op->InputArgumentNames()) {
|
|
|
|
|
if (input_argument_name == var->Name()) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
bool IsPersistable(const framework::VarDesc* var) {
|
|
|
|
|
if (var->Persistable() &&
|
|
|
|
|
var->GetType() != framework::proto::VarDesc::FEED_MINIBATCH &&
|
|
|
|
|
var->GetType() != framework::proto::VarDesc::FETCH_LIST) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
@ -65,8 +53,8 @@ void LoadPersistables(framework::Executor& executor,
|
|
|
|
|
std::vector<std::string> paramlist;
|
|
|
|
|
|
|
|
|
|
for (auto* var : global_block.AllVars()) {
|
|
|
|
|
if (IsParameter(var, main_program)) {
|
|
|
|
|
VLOG(3) << "parameter's name: " << var->Name();
|
|
|
|
|
if (IsPersistable(var)) {
|
|
|
|
|
VLOG(3) << "persistable variable's name: " << var->Name();
|
|
|
|
|
|
|
|
|
|
framework::VarDesc* new_var = load_block->Var(var->Name());
|
|
|
|
|
new_var->SetShape(var->GetShape());
|
|
|
|
@ -101,7 +89,6 @@ void LoadPersistables(framework::Executor& executor,
|
|
|
|
|
|
|
|
|
|
executor.Run(*load_program, &scope, 0, true, true);
|
|
|
|
|
|
|
|
|
|
VLOG(3) << "Ran loading successfully";
|
|
|
|
|
delete load_program;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|