|
|
|
|
@ -92,13 +92,20 @@ static std::vector<std::vector<int>> NgraphOpIntervals(
|
|
|
|
|
std::vector<std::vector<int>> intervals;
|
|
|
|
|
|
|
|
|
|
int size = ops->size();
|
|
|
|
|
int left = 0;
|
|
|
|
|
int left = 0, feed_idx = -1;
|
|
|
|
|
while (left < size && ops->at(left)->Type() != framework::kFeedOpType &&
|
|
|
|
|
ops->at(left)->Type() != "read" &&
|
|
|
|
|
ops->at(left)->Type() != framework::kFetchOpType) {
|
|
|
|
|
++left;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (left < size) {
|
|
|
|
|
auto op_type = ops->at(left)->Type();
|
|
|
|
|
if (op_type == framework::kFeedOpType || op_type == "read") {
|
|
|
|
|
feed_idx = left;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
while (left < size && (ops->at(left)->Type() == framework::kFeedOpType ||
|
|
|
|
|
ops->at(left)->Type() == "read")) {
|
|
|
|
|
for (auto& var_name_item : ops->at(left)->Outputs()) {
|
|
|
|
|
@ -141,7 +148,9 @@ static std::vector<std::vector<int>> NgraphOpIntervals(
|
|
|
|
|
++end;
|
|
|
|
|
}
|
|
|
|
|
std::vector<int> interval = {start, end};
|
|
|
|
|
intervals.emplace_back(interval);
|
|
|
|
|
if (feed_idx != -1 && start > feed_idx) {
|
|
|
|
|
intervals.emplace_back(interval);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
} // end while
|
|
|
|
|
return intervals;
|
|
|
|
|
@ -252,7 +261,6 @@ void NgraphEngine::Prepare(const framework::ExecutionContext& ctx) {
|
|
|
|
|
NgraphEngine::p_bdesc = &block_desc;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool has_fetch = false, is_full = false;
|
|
|
|
|
for (auto& var : p_bdesc->AllVars()) {
|
|
|
|
|
if (!(var->GetType() == framework::proto::VarType::SELECTED_ROWS ||
|
|
|
|
|
var->GetType() == framework::proto::VarType::LOD_TENSOR ||
|
|
|
|
|
@ -283,33 +291,12 @@ void NgraphEngine::Prepare(const framework::ExecutionContext& ctx) {
|
|
|
|
|
std::vector<paddle::framework::OpDesc*> ops_desc;
|
|
|
|
|
for (auto op_desc : p_bdesc->AllOps()) {
|
|
|
|
|
ops_desc.emplace_back(op_desc);
|
|
|
|
|
if (op_desc->Type() == framework::kFetchOpType) {
|
|
|
|
|
has_fetch = true;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (auto op_desc : ops_desc) {
|
|
|
|
|
if (op_desc->Type().find("_grad") != std::string::npos) {
|
|
|
|
|
is_training = true;
|
|
|
|
|
this->is_test_ = false;
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (interval[0] > 0 &&
|
|
|
|
|
ops_desc.at(interval[0] - 1)->Type() == framework::kFeedOpType &&
|
|
|
|
|
interval[1] < static_cast<int>(ops_desc.size()) &&
|
|
|
|
|
ops_desc.at(interval[1])->Type() == framework::kFetchOpType) {
|
|
|
|
|
is_full = true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (is_full) {
|
|
|
|
|
this->op_state_ = this->is_test_ ? OpState::FULL_TEST : OpState::FULL_TRAIN;
|
|
|
|
|
} else {
|
|
|
|
|
this->op_state_ =
|
|
|
|
|
this->is_test_ ? OpState::PARTIAL_TEST : OpState::PARTIAL_TRAIN;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
int idx = interval[0];
|
|
|
|
|
while (idx < interval[1]) {
|
|
|
|
|
this->fused_ops_.emplace_back(
|
|
|
|
|
@ -327,10 +314,6 @@ void NgraphEngine::Prepare(const framework::ExecutionContext& ctx) {
|
|
|
|
|
++idx;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!has_fetch) {
|
|
|
|
|
op_state_ = OpState::UNKNOWN;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (var_in_.empty() && var_out_.empty()) {
|
|
|
|
|
BuildNgIO(ops_desc, interval);
|
|
|
|
|
}
|
|
|
|
|
@ -380,37 +363,19 @@ void NgraphEngine::BuildNgIO(const std::vector<framework::OpDesc*>& ops_desc,
|
|
|
|
|
"op %s has more than 1 output - Not handling yet",
|
|
|
|
|
op->Type());
|
|
|
|
|
for (auto& var_name : var_name_item.second) {
|
|
|
|
|
switch (this->op_state_) {
|
|
|
|
|
case OpState::PARTIAL_TEST:
|
|
|
|
|
if (post_op_inputs_.find(var_name) != post_op_inputs_.end() ||
|
|
|
|
|
find(fetch_vars.begin(), fetch_vars.end(), var_name) !=
|
|
|
|
|
fetch_vars.end()) {
|
|
|
|
|
this->var_out_.emplace_back(var_name);
|
|
|
|
|
}
|
|
|
|
|
break;
|
|
|
|
|
case OpState::FULL_TEST:
|
|
|
|
|
if (find(fetch_vars.begin(), fetch_vars.end(), var_name) !=
|
|
|
|
|
fetch_vars.end()) {
|
|
|
|
|
this->var_out_.emplace_back(var_name);
|
|
|
|
|
}
|
|
|
|
|
break;
|
|
|
|
|
case OpState::PARTIAL_TRAIN:
|
|
|
|
|
if (find(fetch_vars.begin(), fetch_vars.end(), var_name) !=
|
|
|
|
|
fetch_vars.end() ||
|
|
|
|
|
post_op_inputs_.find(var_name) != post_op_inputs_.end() ||
|
|
|
|
|
persistables_.find(var_name) != persistables_.end()) {
|
|
|
|
|
this->var_out_.emplace_back(var_name);
|
|
|
|
|
}
|
|
|
|
|
break;
|
|
|
|
|
case OpState::FULL_TRAIN:
|
|
|
|
|
if (find(fetch_vars.begin(), fetch_vars.end(), var_name) !=
|
|
|
|
|
fetch_vars.end() ||
|
|
|
|
|
persistables_.find(var_name) != persistables_.end()) {
|
|
|
|
|
this->var_out_.emplace_back(var_name);
|
|
|
|
|
}
|
|
|
|
|
break;
|
|
|
|
|
default:
|
|
|
|
|
if (this->is_test_) {
|
|
|
|
|
if (post_op_inputs_.find(var_name) != post_op_inputs_.end() ||
|
|
|
|
|
find(fetch_vars.begin(), fetch_vars.end(), var_name) !=
|
|
|
|
|
fetch_vars.end()) {
|
|
|
|
|
this->var_out_.emplace_back(var_name);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
if (find(fetch_vars.begin(), fetch_vars.end(), var_name) !=
|
|
|
|
|
fetch_vars.end() ||
|
|
|
|
|
post_op_inputs_.find(var_name) != post_op_inputs_.end() ||
|
|
|
|
|
persistables_.find(var_name) != persistables_.end()) {
|
|
|
|
|
this->var_out_.emplace_back(var_name);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|