|
|
|
@ -113,10 +113,11 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
|
|
|
|
|
// and feed_holder_name. Raise exception when any mismatch is found.
|
|
|
|
|
// Return true if the block has feed operators and holder of matching info.
|
|
|
|
|
static bool has_feed_operators(
|
|
|
|
|
BlockDesc* block, std::map<std::string, const LoDTensor*>& feed_targets,
|
|
|
|
|
const BlockDesc& block,
|
|
|
|
|
std::map<std::string, const LoDTensor*>& feed_targets,
|
|
|
|
|
const std::string& feed_holder_name) {
|
|
|
|
|
size_t feed_count = 0;
|
|
|
|
|
for (auto* op : block->AllOps()) {
|
|
|
|
|
for (auto* op : block.AllOps()) {
|
|
|
|
|
if (op->Type() == kFeedOpType) {
|
|
|
|
|
feed_count++;
|
|
|
|
|
PADDLE_ENFORCE_EQ(op->Input("X")[0], feed_holder_name,
|
|
|
|
@ -135,7 +136,7 @@ static bool has_feed_operators(
|
|
|
|
|
"The number of feed operators should match 'feed_targets'");
|
|
|
|
|
|
|
|
|
|
// When feed operator are present, so should be feed_holder
|
|
|
|
|
auto var = block->FindVar(feed_holder_name);
|
|
|
|
|
auto var = block.FindVar(feed_holder_name);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
|
|
|
|
|
feed_holder_name);
|
|
|
|
|
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH,
|
|
|
|
@ -153,10 +154,10 @@ static bool has_feed_operators(
|
|
|
|
|
// and fetch_holder_name. Raise exception when any mismatch is found.
|
|
|
|
|
// Return true if the block has fetch operators and holder of matching info.
|
|
|
|
|
static bool has_fetch_operators(
|
|
|
|
|
BlockDesc* block, std::map<std::string, LoDTensor*>& fetch_targets,
|
|
|
|
|
const BlockDesc& block, std::map<std::string, LoDTensor*>& fetch_targets,
|
|
|
|
|
const std::string& fetch_holder_name) {
|
|
|
|
|
size_t fetch_count = 0;
|
|
|
|
|
for (auto* op : block->AllOps()) {
|
|
|
|
|
for (auto* op : block.AllOps()) {
|
|
|
|
|
if (op->Type() == kFetchOpType) {
|
|
|
|
|
fetch_count++;
|
|
|
|
|
PADDLE_ENFORCE_EQ(op->Output("Out")[0], fetch_holder_name,
|
|
|
|
@ -175,7 +176,7 @@ static bool has_fetch_operators(
|
|
|
|
|
"The number of fetch operators should match 'fetch_targets'");
|
|
|
|
|
|
|
|
|
|
// When fetch operator are present, so should be fetch_holder
|
|
|
|
|
auto var = block->FindVar(fetch_holder_name);
|
|
|
|
|
auto var = block.FindVar(fetch_holder_name);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
|
|
|
|
|
fetch_holder_name);
|
|
|
|
|
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST,
|
|
|
|
@ -192,10 +193,19 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
|
|
|
|
|
const std::string& feed_holder_name,
|
|
|
|
|
const std::string& fetch_holder_name) {
|
|
|
|
|
platform::RecordBlock b(kProgramId);
|
|
|
|
|
auto* copy_program = new ProgramDesc(program);
|
|
|
|
|
bool has_feed_ops =
|
|
|
|
|
has_feed_operators(program.Block(0), feed_targets, feed_holder_name);
|
|
|
|
|
bool has_fetch_ops =
|
|
|
|
|
has_fetch_operators(program.Block(0), fetch_targets, fetch_holder_name);
|
|
|
|
|
|
|
|
|
|
ProgramDesc* copy_program = const_cast<ProgramDesc*>(&program);
|
|
|
|
|
if (!has_feed_ops || !has_fetch_ops) {
|
|
|
|
|
copy_program = std::unique_ptr<ProgramDesc>(new ProgramDesc(program)).get();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto* global_block = copy_program->MutableBlock(0);
|
|
|
|
|
|
|
|
|
|
if (!has_feed_operators(global_block, feed_targets, feed_holder_name)) {
|
|
|
|
|
if (!has_feed_ops) {
|
|
|
|
|
// create feed_holder variable
|
|
|
|
|
auto* feed_holder = global_block->Var(feed_holder_name);
|
|
|
|
|
feed_holder->SetType(proto::VarType::FEED_MINIBATCH);
|
|
|
|
@ -228,7 +238,7 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!has_fetch_operators(global_block, fetch_targets, fetch_holder_name)) {
|
|
|
|
|
if (!has_fetch_ops) {
|
|
|
|
|
// create fetch_holder variable
|
|
|
|
|
auto* fetch_holder = global_block->Var(fetch_holder_name);
|
|
|
|
|
fetch_holder->SetType(proto::VarType::FETCH_LIST);
|
|
|
|
@ -262,8 +272,6 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
|
|
|
|
|
GetFetchVariable(*scope, fetch_holder_name, idx);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
delete copy_program;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ExecutorPrepareContext* Executor::Prepare(const ProgramDesc& program,
|
|
|
|
@ -313,9 +321,8 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
|
|
|
|
|
} // if (create_vars)
|
|
|
|
|
|
|
|
|
|
for (auto& op : ctx->ops_) {
|
|
|
|
|
VLOG(4) << place_ << " " << op->DebugStringEx(local_scope);
|
|
|
|
|
op->Run(*local_scope, place_);
|
|
|
|
|
VLOG(3) << place_ << " " << op->DebugStringEx(local_scope);
|
|
|
|
|
op->Run(*local_scope, place_);
|
|
|
|
|
|
|
|
|
|
if (FLAGS_benchmark) {
|
|
|
|
|
VLOG(2) << "Memory used after operator " + op->Type() + " running: "
|
|
|
|
|