|
|
@ -92,12 +92,10 @@ static std::vector<std::vector<int>> NgraphOpIntervals(
|
|
|
|
|
|
|
|
|
|
|
|
int size = ops->size();
|
|
|
|
int size = ops->size();
|
|
|
|
int left = 0;
|
|
|
|
int left = 0;
|
|
|
|
while (left < size && ops->at(left)->Type() != framework::kFeedOpType) {
|
|
|
|
while (left < size && ops->at(left)->Type() != framework::kFeedOpType &&
|
|
|
|
|
|
|
|
ops->at(left)->Type() != framework::kFetchOpType) {
|
|
|
|
++left;
|
|
|
|
++left;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (left == size) {
|
|
|
|
|
|
|
|
return intervals;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
while (left < size && ops->at(left)->Type() == framework::kFeedOpType) {
|
|
|
|
while (left < size && ops->at(left)->Type() == framework::kFeedOpType) {
|
|
|
|
for (auto& var_name_item : ops->at(left)->Outputs()) {
|
|
|
|
for (auto& var_name_item : ops->at(left)->Outputs()) {
|
|
|
@ -112,10 +110,6 @@ static std::vector<std::vector<int>> NgraphOpIntervals(
|
|
|
|
while (right < size && ops->at(right)->Type() != framework::kFetchOpType) {
|
|
|
|
while (right < size && ops->at(right)->Type() != framework::kFetchOpType) {
|
|
|
|
++right;
|
|
|
|
++right;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
if (right == size) {
|
|
|
|
|
|
|
|
return intervals;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
if (left >= right) return intervals;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
int index = right;
|
|
|
|
int index = right;
|
|
|
|
while (index < size && ops->at(index)->Type() == framework::kFetchOpType) {
|
|
|
|
while (index < size && ops->at(index)->Type() == framework::kFetchOpType) {
|
|
|
@ -127,6 +121,10 @@ static std::vector<std::vector<int>> NgraphOpIntervals(
|
|
|
|
++index;
|
|
|
|
++index;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (left == size || ops->at(left)->Type() == framework::kFetchOpType) {
|
|
|
|
|
|
|
|
left = 0;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
// (left, right - 1) represents indices between feed and fetch
|
|
|
|
// (left, right - 1) represents indices between feed and fetch
|
|
|
|
int pivot = left;
|
|
|
|
int pivot = left;
|
|
|
|
while (pivot < right) {
|
|
|
|
while (pivot < right) {
|
|
|
@ -234,6 +232,7 @@ NgraphEngine::NgraphEngine(const framework::Scope& scope,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
void NgraphEngine::Prepare(const std::vector<int>& interval) {
|
|
|
|
void NgraphEngine::Prepare(const std::vector<int>& interval) {
|
|
|
|
|
|
|
|
bool has_fetch = false, is_full = false;
|
|
|
|
for (auto& var : p_bdesc->AllVars()) {
|
|
|
|
for (auto& var : p_bdesc->AllVars()) {
|
|
|
|
if (!(var->GetType() == framework::proto::VarType::SELECTED_ROWS ||
|
|
|
|
if (!(var->GetType() == framework::proto::VarType::SELECTED_ROWS ||
|
|
|
|
var->GetType() == framework::proto::VarType::LOD_TENSOR ||
|
|
|
|
var->GetType() == framework::proto::VarType::LOD_TENSOR ||
|
|
|
@ -264,6 +263,9 @@ void NgraphEngine::Prepare(const std::vector<int>& interval) {
|
|
|
|
std::vector<paddle::framework::OpDesc*> ops_desc;
|
|
|
|
std::vector<paddle::framework::OpDesc*> ops_desc;
|
|
|
|
for (auto op_desc : p_bdesc->AllOps()) {
|
|
|
|
for (auto op_desc : p_bdesc->AllOps()) {
|
|
|
|
ops_desc.emplace_back(op_desc);
|
|
|
|
ops_desc.emplace_back(op_desc);
|
|
|
|
|
|
|
|
if (op_desc->Type() == framework::kFetchOpType) {
|
|
|
|
|
|
|
|
has_fetch = true;
|
|
|
|
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
for (auto op_desc : ops_desc) {
|
|
|
|
for (auto op_desc : ops_desc) {
|
|
|
@ -276,11 +278,11 @@ void NgraphEngine::Prepare(const std::vector<int>& interval) {
|
|
|
|
if (interval[0] > 0 &&
|
|
|
|
if (interval[0] > 0 &&
|
|
|
|
ops_desc.at(interval[0] - 1)->Type() == framework::kFeedOpType &&
|
|
|
|
ops_desc.at(interval[0] - 1)->Type() == framework::kFeedOpType &&
|
|
|
|
interval[1] < static_cast<int>(ops_desc.size()) &&
|
|
|
|
interval[1] < static_cast<int>(ops_desc.size()) &&
|
|
|
|
ops_desc.at(interval.at(1))->Type() == framework::kFetchOpType) {
|
|
|
|
ops_desc.at(interval[1])->Type() == framework::kFetchOpType) {
|
|
|
|
this->op_state_ = OpState::FULL;
|
|
|
|
is_full = true;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
if (this->op_state_ == OpState::FULL) {
|
|
|
|
if (is_full) {
|
|
|
|
this->op_state_ = this->is_test_ ? OpState::FULL_TEST : OpState::FULL_TRAIN;
|
|
|
|
this->op_state_ = this->is_test_ ? OpState::FULL_TEST : OpState::FULL_TRAIN;
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
this->op_state_ =
|
|
|
|
this->op_state_ =
|
|
|
@ -293,7 +295,8 @@ void NgraphEngine::Prepare(const std::vector<int>& interval) {
|
|
|
|
framework::OpRegistry::CreateOp(*(ops_desc[idx])));
|
|
|
|
framework::OpRegistry::CreateOp(*(ops_desc[idx])));
|
|
|
|
++idx;
|
|
|
|
++idx;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
while (ops_desc.at(idx)->Type() != framework::kFetchOpType) {
|
|
|
|
while (idx < static_cast<int>(ops_desc.size()) &&
|
|
|
|
|
|
|
|
ops_desc.at(idx)->Type() != framework::kFetchOpType) {
|
|
|
|
auto op_desc = ops_desc.at(idx);
|
|
|
|
auto op_desc = ops_desc.at(idx);
|
|
|
|
for (auto& var_name_item : op_desc->Inputs()) {
|
|
|
|
for (auto& var_name_item : op_desc->Inputs()) {
|
|
|
|
for (auto& var_name : var_name_item.second) {
|
|
|
|
for (auto& var_name : var_name_item.second) {
|
|
|
@ -303,6 +306,10 @@ void NgraphEngine::Prepare(const std::vector<int>& interval) {
|
|
|
|
++idx;
|
|
|
|
++idx;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (!has_fetch) {
|
|
|
|
|
|
|
|
op_state_ = OpState::UNKNOWN;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
BuildNgIO(ops_desc, interval);
|
|
|
|
BuildNgIO(ops_desc, interval);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
@ -378,6 +385,7 @@ void NgraphEngine::BuildNgIO(const std::vector<framework::OpDesc*>& ops_desc,
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < var_in_.size(); ++i) {
|
|
|
|
for (size_t i = 0; i < var_in_.size(); ++i) {
|
|
|
|
auto var_name = var_in_[i];
|
|
|
|
auto var_name = var_in_[i];
|
|
|
|
if (persistables_.find(var_name) == persistables_.end()) {
|
|
|
|
if (persistables_.find(var_name) == persistables_.end()) {
|
|
|
|