|
|
|
|
@ -129,13 +129,15 @@ static bool has_feed_operators(
|
|
|
|
|
feed_count, feed_targets.size(),
|
|
|
|
|
"The number of feed operators should match 'feed_targets'");
|
|
|
|
|
|
|
|
|
|
// When feed operator are present, so should be feed_holder
|
|
|
|
|
auto var = block.FindVar(feed_holder_name);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
|
|
|
|
|
feed_holder_name);
|
|
|
|
|
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH,
|
|
|
|
|
"'%s' variable should be 'FEED_MINIBATCH' type",
|
|
|
|
|
feed_holder_name);
|
|
|
|
|
if (!feed_holder_name.empty()) {
|
|
|
|
|
// When feed operator are present, so should be feed_holder
|
|
|
|
|
auto var = block.FindVar(feed_holder_name);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
|
|
|
|
|
feed_holder_name);
|
|
|
|
|
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH,
|
|
|
|
|
"'%s' variable should be 'FEED_MINIBATCH' type",
|
|
|
|
|
feed_holder_name);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return feed_count > 0;
|
|
|
|
|
@ -169,13 +171,15 @@ static bool has_fetch_operators(
|
|
|
|
|
fetch_count, fetch_targets.size(),
|
|
|
|
|
"The number of fetch operators should match 'fetch_targets'");
|
|
|
|
|
|
|
|
|
|
// When fetch operator are present, so should be fetch_holder
|
|
|
|
|
auto var = block.FindVar(fetch_holder_name);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
|
|
|
|
|
fetch_holder_name);
|
|
|
|
|
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST,
|
|
|
|
|
"'%s' variable should be 'FETCH_LIST' type",
|
|
|
|
|
fetch_holder_name);
|
|
|
|
|
if (!fetch_holder_name.empty()) {
|
|
|
|
|
// When fetch operator are present, so should be fetch_holder
|
|
|
|
|
auto var = block.FindVar(fetch_holder_name);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
|
|
|
|
|
fetch_holder_name);
|
|
|
|
|
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST,
|
|
|
|
|
"'%s' variable should be 'FETCH_LIST' type",
|
|
|
|
|
fetch_holder_name);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return fetch_count > 0;
|
|
|
|
|
@ -222,16 +226,6 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// map the data of feed_targets to feed_holder
|
|
|
|
|
for (auto* op : global_block->AllOps()) {
|
|
|
|
|
if (op->Type() == kFeedOpType) {
|
|
|
|
|
std::string feed_target_name = op->Output("Out")[0];
|
|
|
|
|
int idx = boost::get<int>(op->GetAttr("col"));
|
|
|
|
|
SetFeedVariable(scope, *feed_targets[feed_target_name], feed_holder_name,
|
|
|
|
|
idx);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!has_fetch_ops) {
|
|
|
|
|
// create fetch_holder variable
|
|
|
|
|
auto* fetch_holder = global_block->Var(fetch_holder_name);
|
|
|
|
|
@ -255,17 +249,9 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Run(*copy_program, scope, 0, create_vars, create_vars);
|
|
|
|
|
|
|
|
|
|
// obtain the data of fetch_targets from fetch_holder
|
|
|
|
|
for (auto* op : global_block->AllOps()) {
|
|
|
|
|
if (op->Type() == kFetchOpType) {
|
|
|
|
|
std::string fetch_target_name = op->Input("X")[0];
|
|
|
|
|
int idx = boost::get<int>(op->GetAttr("col"));
|
|
|
|
|
*fetch_targets[fetch_target_name] =
|
|
|
|
|
GetFetchVariable(*scope, fetch_holder_name, idx);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
auto ctx = Prepare(*copy_program, 0);
|
|
|
|
|
RunPreparedContext(ctx.get(), scope, feed_targets, fetch_targets,
|
|
|
|
|
feed_holder_name, fetch_holder_name, create_vars);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
|
|
|
|
|
@ -343,5 +329,43 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Executor::RunPreparedContext(
|
|
|
|
|
ExecutorPrepareContext* ctx, Scope* scope,
|
|
|
|
|
std::map<std::string, const LoDTensor*>& feed_targets,
|
|
|
|
|
std::map<std::string, LoDTensor*>& fetch_targets,
|
|
|
|
|
const std::string& feed_holder_name, const std::string& fetch_holder_name,
|
|
|
|
|
bool create_vars) {
|
|
|
|
|
auto& global_block = ctx->prog_.Block(ctx->block_id_);
|
|
|
|
|
|
|
|
|
|
// map the data of feed_targets to feed_holder
|
|
|
|
|
for (auto* op : global_block.AllOps()) {
|
|
|
|
|
if (op->Type() == kFeedOpType) {
|
|
|
|
|
std::string feed_target_name = op->Output("Out")[0];
|
|
|
|
|
PADDLE_ENFORCE(feed_targets.find(feed_target_name) != feed_targets.end(),
|
|
|
|
|
"Variable %s is not feeded.");
|
|
|
|
|
|
|
|
|
|
int idx = boost::get<int>(op->GetAttr("col"));
|
|
|
|
|
SetFeedVariable(scope, *feed_targets[feed_target_name], feed_holder_name,
|
|
|
|
|
idx);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
RunPreparedContext(ctx, scope, create_vars, create_vars);
|
|
|
|
|
|
|
|
|
|
// obtain the data of fetch_targets from fetch_holder
|
|
|
|
|
for (auto* op : global_block.AllOps()) {
|
|
|
|
|
if (op->Type() == kFetchOpType) {
|
|
|
|
|
std::string fetch_target_name = op->Input("X")[0];
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
fetch_targets.find(fetch_target_name) != fetch_targets.end(),
|
|
|
|
|
"Variable %s is not fetched.");
|
|
|
|
|
|
|
|
|
|
int idx = boost::get<int>(op->GetAttr("col"));
|
|
|
|
|
*fetch_targets[fetch_target_name] =
|
|
|
|
|
GetFetchVariable(*scope, fetch_holder_name, idx);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace framework
|
|
|
|
|
} // namespace paddle
|
|
|
|
|
|