|
|
|
@ -83,8 +83,8 @@ static void CheckTensorNANOrInf(const std::string& name,
|
|
|
|
|
if (tensor.memory_size() == 0) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
if (tensor.type().hash_code() != typeid(float).hash_code() &&
|
|
|
|
|
tensor.type().hash_code() != typeid(double).hash_code()) {
|
|
|
|
|
if (tensor.type().hash_code() != typeid(float).hash_code() && // NOLINT
|
|
|
|
|
tensor.type().hash_code() != typeid(double).hash_code()) { // NOLINT
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
PADDLE_ENFORCE(!framework::TensorContainsInf(tensor),
|
|
|
|
@ -145,12 +145,13 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
|
|
|
|
|
// Return true if the block has feed operators and holder of matching info.
|
|
|
|
|
static bool has_feed_operators(
|
|
|
|
|
const BlockDesc& block,
|
|
|
|
|
std::map<std::string, const LoDTensor*>& feed_targets,
|
|
|
|
|
const std::map<std::string, const LoDTensor*>& feed_targets,
|
|
|
|
|
const std::string& feed_holder_name) {
|
|
|
|
|
size_t feed_count = 0;
|
|
|
|
|
for (auto* op : block.AllOps()) {
|
|
|
|
|
if (op->Type() == kFeedOpType) {
|
|
|
|
|
feed_count++;
|
|
|
|
|
// The input variable's name of feed_op should be feed_holder_name.
|
|
|
|
|
PADDLE_ENFORCE_EQ(op->Input("X")[0], feed_holder_name,
|
|
|
|
|
"Input to feed op should be '%s'", feed_holder_name);
|
|
|
|
|
std::string feed_target_name = op->Output("Out")[0];
|
|
|
|
@ -166,13 +167,15 @@ static bool has_feed_operators(
|
|
|
|
|
feed_count, feed_targets.size(),
|
|
|
|
|
"The number of feed operators should match 'feed_targets'");
|
|
|
|
|
|
|
|
|
|
// When feed operator are present, so should be feed_holder
|
|
|
|
|
auto var = block.FindVar(feed_holder_name);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
|
|
|
|
|
feed_holder_name);
|
|
|
|
|
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH,
|
|
|
|
|
"'%s' variable should be 'FEED_MINIBATCH' type",
|
|
|
|
|
feed_holder_name);
|
|
|
|
|
if (!feed_holder_name.empty()) {
|
|
|
|
|
// When feed operator are present, so should be feed_holder.
|
|
|
|
|
auto var = block.FindVar(feed_holder_name);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
|
|
|
|
|
feed_holder_name);
|
|
|
|
|
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH,
|
|
|
|
|
"'%s' variable should be 'FEED_MINIBATCH' type",
|
|
|
|
|
feed_holder_name);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return feed_count > 0;
|
|
|
|
@ -185,12 +188,14 @@ static bool has_feed_operators(
|
|
|
|
|
// and fetch_holder_name. Raise exception when any mismatch is found.
|
|
|
|
|
// Return true if the block has fetch operators and holder of matching info.
|
|
|
|
|
static bool has_fetch_operators(
|
|
|
|
|
const BlockDesc& block, std::map<std::string, LoDTensor*>& fetch_targets,
|
|
|
|
|
const BlockDesc& block,
|
|
|
|
|
const std::map<std::string, LoDTensor*>& fetch_targets,
|
|
|
|
|
const std::string& fetch_holder_name) {
|
|
|
|
|
size_t fetch_count = 0;
|
|
|
|
|
for (auto* op : block.AllOps()) {
|
|
|
|
|
if (op->Type() == kFetchOpType) {
|
|
|
|
|
fetch_count++;
|
|
|
|
|
// The output variable's name of fetch_op should be fetch_holder_name.
|
|
|
|
|
PADDLE_ENFORCE_EQ(op->Output("Out")[0], fetch_holder_name,
|
|
|
|
|
"Output of fetch op should be '%s'", fetch_holder_name);
|
|
|
|
|
std::string fetch_target_name = op->Input("X")[0];
|
|
|
|
@ -206,13 +211,15 @@ static bool has_fetch_operators(
|
|
|
|
|
fetch_count, fetch_targets.size(),
|
|
|
|
|
"The number of fetch operators should match 'fetch_targets'");
|
|
|
|
|
|
|
|
|
|
// When fetch operator are present, so should be fetch_holder
|
|
|
|
|
auto var = block.FindVar(fetch_holder_name);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
|
|
|
|
|
fetch_holder_name);
|
|
|
|
|
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST,
|
|
|
|
|
"'%s' variable should be 'FETCH_LIST' type",
|
|
|
|
|
fetch_holder_name);
|
|
|
|
|
if (!fetch_holder_name.empty()) {
|
|
|
|
|
// When fetch operator are present, so should be fetch_holder.
|
|
|
|
|
auto var = block.FindVar(fetch_holder_name);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
|
|
|
|
|
fetch_holder_name);
|
|
|
|
|
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST,
|
|
|
|
|
"'%s' variable should be 'FETCH_LIST' type",
|
|
|
|
|
fetch_holder_name);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return fetch_count > 0;
|
|
|
|
@ -259,16 +266,6 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// map the data of feed_targets to feed_holder
|
|
|
|
|
for (auto* op : global_block->AllOps()) {
|
|
|
|
|
if (op->Type() == kFeedOpType) {
|
|
|
|
|
std::string feed_target_name = op->Output("Out")[0];
|
|
|
|
|
int idx = boost::get<int>(op->GetAttr("col"));
|
|
|
|
|
SetFeedVariable(scope, *feed_targets[feed_target_name], feed_holder_name,
|
|
|
|
|
idx);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!has_fetch_ops) {
|
|
|
|
|
// create fetch_holder variable
|
|
|
|
|
auto* fetch_holder = global_block->Var(fetch_holder_name);
|
|
|
|
@ -292,17 +289,9 @@ void Executor::Run(const ProgramDesc& program, Scope* scope,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Run(*copy_program, scope, 0, create_vars, create_vars);
|
|
|
|
|
|
|
|
|
|
// obtain the data of fetch_targets from fetch_holder
|
|
|
|
|
for (auto* op : global_block->AllOps()) {
|
|
|
|
|
if (op->Type() == kFetchOpType) {
|
|
|
|
|
std::string fetch_target_name = op->Input("X")[0];
|
|
|
|
|
int idx = boost::get<int>(op->GetAttr("col"));
|
|
|
|
|
*fetch_targets[fetch_target_name] =
|
|
|
|
|
GetFetchVariable(*scope, fetch_holder_name, idx);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
auto ctx = Prepare(*copy_program, 0);
|
|
|
|
|
RunPreparedContext(ctx.get(), scope, feed_targets, fetch_targets, create_vars,
|
|
|
|
|
feed_holder_name, fetch_holder_name);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
|
|
|
|
@ -370,5 +359,42 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Executor::RunPreparedContext(
|
|
|
|
|
ExecutorPrepareContext* ctx, Scope* scope,
|
|
|
|
|
std::map<std::string, const LoDTensor*>& feed_targets,
|
|
|
|
|
std::map<std::string, LoDTensor*>& fetch_targets, bool create_vars,
|
|
|
|
|
const std::string& feed_holder_name, const std::string& fetch_holder_name) {
|
|
|
|
|
auto& global_block = ctx->prog_.Block(ctx->block_id_);
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
has_feed_operators(global_block, feed_targets, feed_holder_name),
|
|
|
|
|
"Program in ExecutorPrepareContext should has feed_ops.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
has_fetch_operators(global_block, fetch_targets, fetch_holder_name),
|
|
|
|
|
"Program in the prepared context should has fetch_ops.");
|
|
|
|
|
|
|
|
|
|
// map the data of feed_targets to feed_holder
|
|
|
|
|
for (auto* op : global_block.AllOps()) {
|
|
|
|
|
if (op->Type() == kFeedOpType) {
|
|
|
|
|
std::string feed_target_name = op->Output("Out")[0];
|
|
|
|
|
int idx = boost::get<int>(op->GetAttr("col"));
|
|
|
|
|
SetFeedVariable(scope, *feed_targets[feed_target_name], feed_holder_name,
|
|
|
|
|
idx);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
RunPreparedContext(ctx, scope, create_vars, create_vars);
|
|
|
|
|
|
|
|
|
|
// obtain the data of fetch_targets from fetch_holder
|
|
|
|
|
for (auto* op : global_block.AllOps()) {
|
|
|
|
|
if (op->Type() == kFetchOpType) {
|
|
|
|
|
std::string fetch_target_name = op->Input("X")[0];
|
|
|
|
|
int idx = boost::get<int>(op->GetAttr("col"));
|
|
|
|
|
*fetch_targets[fetch_target_name] =
|
|
|
|
|
GetFetchVariable(*scope, fetch_holder_name, idx);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
} // namespace framework
|
|
|
|
|
} // namespace paddle
|
|
|
|
|