|
|
|
@ -21,6 +21,7 @@ limitations under the License. */
|
|
|
|
|
#include "google/protobuf/io/zero_copy_stream_impl.h"
|
|
|
|
|
#include "google/protobuf/message.h"
|
|
|
|
|
#include "google/protobuf/text_format.h"
|
|
|
|
|
#include "paddle/fluid/framework/data_type.h"
|
|
|
|
|
#include "paddle/fluid/framework/feed_fetch_method.h"
|
|
|
|
|
#include "paddle/fluid/framework/lod_rank_table.h"
|
|
|
|
|
#include "paddle/fluid/framework/lod_tensor_array.h"
|
|
|
|
@ -145,8 +146,10 @@ std::shared_ptr<TrainerBase> Executor::InitForDataset(
|
|
|
|
|
VLOG(3) << "Start to RunFromDataset in executor";
|
|
|
|
|
TrainerDesc trainer_desc;
|
|
|
|
|
bool success = trainer_desc.ParseFromString(trainer_desc_str);
|
|
|
|
|
PADDLE_ENFORCE_EQ(success, true, "Fail to parse TrainerDesc from string:\n%s",
|
|
|
|
|
trainer_desc_str.c_str());
|
|
|
|
|
PADDLE_ENFORCE_EQ(success, true,
|
|
|
|
|
platform::errors::PreconditionNotMet(
|
|
|
|
|
"Fail to parse TrainerDesc from string:\n%s",
|
|
|
|
|
trainer_desc_str.c_str()));
|
|
|
|
|
VLOG(3) << "Going to create trainer, trainer class is "
|
|
|
|
|
<< trainer_desc.class_name();
|
|
|
|
|
std::shared_ptr<TrainerBase> trainer;
|
|
|
|
@ -165,8 +168,9 @@ std::shared_ptr<TrainerBase> Executor::InitForDataset(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Executor::RunFromDataset(std::shared_ptr<TrainerBase> trainer) {
|
|
|
|
|
PADDLE_ENFORCE_NE(trainer, nullptr,
|
|
|
|
|
"Trainer is nullptr, invoke InitForDataset first");
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
trainer, platform::errors::InvalidArgument(
|
|
|
|
|
"Trainer is nullptr, invoke InitForDataset first"));
|
|
|
|
|
// training and finalize training
|
|
|
|
|
VLOG(3) << "Trainer starts to run";
|
|
|
|
|
trainer->Run();
|
|
|
|
@ -203,29 +207,41 @@ static bool has_feed_operators(
|
|
|
|
|
if (op->Type() == kFeedOpType) {
|
|
|
|
|
feed_count++;
|
|
|
|
|
// The input variable's name of feed_op should be feed_holder_name.
|
|
|
|
|
PADDLE_ENFORCE_EQ(op->Input("X")[0], feed_holder_name,
|
|
|
|
|
"Input to feed op should be '%s'", feed_holder_name);
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
op->Input("X")[0], feed_holder_name,
|
|
|
|
|
platform::errors::PreconditionNotMet(
|
|
|
|
|
"Input to feed op should be '%s', but received '%s'.",
|
|
|
|
|
feed_holder_name, op->Input("X")[0]));
|
|
|
|
|
std::string feed_target_name = op->Output("Out")[0];
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
feed_targets.find(feed_target_name) != feed_targets.end(),
|
|
|
|
|
"Feed operator output name '%s' cannot be found in 'feed_targets'",
|
|
|
|
|
feed_target_name);
|
|
|
|
|
PADDLE_ENFORCE_NE(feed_targets.find(feed_target_name), feed_targets.end(),
|
|
|
|
|
platform::errors::PreconditionNotMet(
|
|
|
|
|
"Feed operator output name '%s' cannot be found in "
|
|
|
|
|
"'feed_targets'",
|
|
|
|
|
feed_target_name));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (feed_count > 0) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
feed_count, feed_targets.size(),
|
|
|
|
|
"The number of feed operators should match 'feed_targets'");
|
|
|
|
|
platform::errors::PreconditionNotMet(
|
|
|
|
|
"The number of feed operators should match 'feed_targets', but "
|
|
|
|
|
"received feed_count: %zu, required feed_targets.size(): %zu.",
|
|
|
|
|
feed_count, feed_targets.size()));
|
|
|
|
|
|
|
|
|
|
if (!feed_holder_name.empty()) {
|
|
|
|
|
// When feed operator are present, so should be feed_holder.
|
|
|
|
|
auto var = block.FindVar(feed_holder_name);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
|
|
|
|
|
feed_holder_name);
|
|
|
|
|
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FEED_MINIBATCH,
|
|
|
|
|
"'%s' variable should be 'FEED_MINIBATCH' type",
|
|
|
|
|
feed_holder_name);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
var,
|
|
|
|
|
platform::errors::PreconditionNotMet(
|
|
|
|
|
"Block should already have a '%s' variable", feed_holder_name));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
var->GetType(), proto::VarType::FEED_MINIBATCH,
|
|
|
|
|
platform::errors::PreconditionNotMet(
|
|
|
|
|
"'%s' variable should be 'FEED_MINIBATCH' type, but received "
|
|
|
|
|
"'%s'.",
|
|
|
|
|
feed_holder_name, DataTypeToString(var->GetType())));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -247,29 +263,41 @@ static bool has_fetch_operators(
|
|
|
|
|
if (op->Type() == kFetchOpType) {
|
|
|
|
|
fetch_count++;
|
|
|
|
|
// The output variable's name of fetch_op should be fetch_holder_name.
|
|
|
|
|
PADDLE_ENFORCE_EQ(op->Output("Out")[0], fetch_holder_name,
|
|
|
|
|
"Output of fetch op should be '%s'", fetch_holder_name);
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
op->Output("Out")[0], fetch_holder_name,
|
|
|
|
|
platform::errors::PreconditionNotMet(
|
|
|
|
|
"Output of fetch op should be '%s', but received '%s'.",
|
|
|
|
|
fetch_holder_name, op->Output("Out")[0]));
|
|
|
|
|
std::string fetch_target_name = op->Input("X")[0];
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
fetch_targets.find(fetch_target_name) != fetch_targets.end(),
|
|
|
|
|
"Fetch operator input name '%s' cannot be found in 'fetch_targets'",
|
|
|
|
|
fetch_target_name);
|
|
|
|
|
PADDLE_ENFORCE_NE(fetch_targets.find(fetch_target_name),
|
|
|
|
|
fetch_targets.end(),
|
|
|
|
|
platform::errors::NotFound(
|
|
|
|
|
"Fetch operator input name '%s' cannot be found in "
|
|
|
|
|
"'fetch_targets'.",
|
|
|
|
|
fetch_target_name));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (fetch_count > 0) {
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
fetch_count, fetch_targets.size(),
|
|
|
|
|
"The number of fetch operators should match 'fetch_targets'");
|
|
|
|
|
platform::errors::PreconditionNotMet(
|
|
|
|
|
"The number of fetch operators should match 'fetch_targets', but "
|
|
|
|
|
"received fetch_count: %zu, required fetch_targets.size(): %zu.",
|
|
|
|
|
fetch_count, fetch_targets.size()));
|
|
|
|
|
|
|
|
|
|
if (!fetch_holder_name.empty()) {
|
|
|
|
|
// When fetch operator are present, so should be fetch_holder.
|
|
|
|
|
auto var = block.FindVar(fetch_holder_name);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(var, "Block should already have a '%s' variable",
|
|
|
|
|
fetch_holder_name);
|
|
|
|
|
PADDLE_ENFORCE_EQ(var->GetType(), proto::VarType::FETCH_LIST,
|
|
|
|
|
"'%s' variable should be 'FETCH_LIST' type",
|
|
|
|
|
fetch_holder_name);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
var,
|
|
|
|
|
platform::errors::PreconditionNotMet(
|
|
|
|
|
"Block should already have a '%s' variable.", fetch_holder_name));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
var->GetType(), proto::VarType::FETCH_LIST,
|
|
|
|
|
platform::errors::PreconditionNotMet(
|
|
|
|
|
"'%s' variable should be 'FETCH_LIST' type, but received '%s'.",
|
|
|
|
|
fetch_holder_name, DataTypeToString(var->GetType())));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -354,7 +382,11 @@ std::unique_ptr<ExecutorPrepareContext> Executor::Prepare(
|
|
|
|
|
const std::vector<std::string>& skip_ref_cnt_vars, bool force_disable_gc) {
|
|
|
|
|
std::unique_ptr<ExecutorPrepareContext> ctx(
|
|
|
|
|
new ExecutorPrepareContext(program, block_id));
|
|
|
|
|
PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size());
|
|
|
|
|
PADDLE_ENFORCE_LT(static_cast<size_t>(block_id), program.Size(),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input block id = %d, but it should be less than "
|
|
|
|
|
"program.size() which is %d",
|
|
|
|
|
static_cast<size_t>(block_id), program.Size()));
|
|
|
|
|
auto& block = program.Block(block_id);
|
|
|
|
|
for (auto& op_desc : block.AllOps()) {
|
|
|
|
|
ctx->ops_.push_back(OpRegistry::CreateOp(*op_desc));
|
|
|
|
@ -367,14 +399,20 @@ std::vector<std::shared_ptr<ExecutorPrepareContext>> Executor::Prepare(
|
|
|
|
|
const ProgramDesc& program, const std::vector<int>& block_ids,
|
|
|
|
|
const std::vector<std::vector<std::string>>& skip_ref_cnt_vars,
|
|
|
|
|
bool force_disable_gc) {
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
skip_ref_cnt_vars.empty() || skip_ref_cnt_vars.size() == block_ids.size(),
|
|
|
|
|
"skip_ref_cnt_vars should be either empty or equals to block number %d",
|
|
|
|
|
block_ids.size());
|
|
|
|
|
true,
|
|
|
|
|
platform::errors::InvalidArgument("skip_ref_cnt_vars should be either "
|
|
|
|
|
"empty or equals to block number %d",
|
|
|
|
|
block_ids.size()));
|
|
|
|
|
std::vector<std::shared_ptr<ExecutorPrepareContext>> result;
|
|
|
|
|
size_t idx = 0;
|
|
|
|
|
for (auto& bid : block_ids) {
|
|
|
|
|
PADDLE_ENFORCE_LT(static_cast<size_t>(bid), program.Size());
|
|
|
|
|
PADDLE_ENFORCE_LT(static_cast<size_t>(bid), program.Size(),
|
|
|
|
|
platform::errors::InvalidArgument(
|
|
|
|
|
"Input block id = %zu, but it should be less than "
|
|
|
|
|
"program.size() which is %zu",
|
|
|
|
|
static_cast<size_t>(bid), program.Size()));
|
|
|
|
|
auto* ctx = new ExecutorPrepareContext(program, bid);
|
|
|
|
|
auto& block = program.Block(bid);
|
|
|
|
|
for (auto& op_desc : block.AllOps()) {
|
|
|
|
@ -397,7 +435,8 @@ void Executor::RunPartialPreparedContext(ExecutorPrepareContext* ctx,
|
|
|
|
|
bool create_local_scope,
|
|
|
|
|
bool create_vars, bool keep_kids) {
|
|
|
|
|
platform::RecordBlock b(kProgramId);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(scope);
|
|
|
|
|
PADDLE_ENFORCE_NOT_NULL(
|
|
|
|
|
scope, platform::errors::InvalidArgument("Scope shouldn't be null"));
|
|
|
|
|
Scope* local_scope = scope;
|
|
|
|
|
if (create_vars) {
|
|
|
|
|
if (create_local_scope) {
|
|
|
|
@ -470,12 +509,14 @@ void Executor::RunPreparedContext(
|
|
|
|
|
const std::string& fetch_holder_name) {
|
|
|
|
|
auto& global_block = ctx->prog_.Block(ctx->block_id_);
|
|
|
|
|
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
has_feed_operators(global_block, *feed_targets, feed_holder_name),
|
|
|
|
|
"Program in ExecutorPrepareContext should has feed_ops.");
|
|
|
|
|
PADDLE_ENFORCE(
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
has_feed_operators(global_block, *feed_targets, feed_holder_name), true,
|
|
|
|
|
platform::errors::PreconditionNotMet(
|
|
|
|
|
"Program in ExecutorPrepareContext should has feed_ops."));
|
|
|
|
|
PADDLE_ENFORCE_EQ(
|
|
|
|
|
has_fetch_operators(global_block, *fetch_targets, fetch_holder_name),
|
|
|
|
|
"Program in the prepared context should has fetch_ops.");
|
|
|
|
|
true, platform::errors::PreconditionNotMet(
|
|
|
|
|
"Program in the prepared context should has fetch_ops."));
|
|
|
|
|
|
|
|
|
|
// map the data of feed_targets to feed_holder
|
|
|
|
|
for (auto* op : global_block.AllOps()) {
|
|
|
|
|