Minor ngraph fix (#16270)

* take care of edge cases test=develop

* use pragma test=develop
Author: baojun (committed by tensor-tang)
Parent: 9195c3bb03
Commit: 804afc51db

paddle/fluid/operators/ngraph/ngraph_engine.cc

@@ -92,12 +92,10 @@ static std::vector<std::vector<int>> NgraphOpIntervals(
   int size = ops->size();
   int left = 0;
-  while (left < size && ops->at(left)->Type() != framework::kFeedOpType) {
+  while (left < size && ops->at(left)->Type() != framework::kFeedOpType &&
+         ops->at(left)->Type() != framework::kFetchOpType) {
     ++left;
   }
-  if (left == size) {
-    return intervals;
-  }
 
   while (left < size && ops->at(left)->Type() == framework::kFeedOpType) {
     for (auto& var_name_item : ops->at(left)->Outputs()) {
@@ -112,10 +110,6 @@ static std::vector<std::vector<int>> NgraphOpIntervals(
   while (right < size && ops->at(right)->Type() != framework::kFetchOpType) {
     ++right;
   }
-  if (right == size) {
-    return intervals;
-  }
-  if (left >= right) return intervals;
 
   int index = right;
   while (index < size && ops->at(index)->Type() == framework::kFetchOpType) {
@@ -127,6 +121,10 @@ static std::vector<std::vector<int>> NgraphOpIntervals(
     ++index;
   }
+  if (left == size || ops->at(left)->Type() == framework::kFetchOpType) {
+    left = 0;
+  }
+
   // (left, right - 1) represents indices between feed and fetch
   int pivot = left;
   while (pivot < right) {
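
Taken together, the three hunks above change NgraphOpIntervals so that a block with no feed section, no fetch section, or neither no longer returns an empty interval list up front; the scan now also stops at fetch ops and falls back to index 0 when no feed boundary exists. A minimal standalone sketch of the resulting scan, using plain strings in place of OpDesc (ScanBounds and the "feed"/"fetch" literals are illustrative stand-ins, not Paddle API):

    #include <string>
    #include <vector>

    // Illustrative stand-in for the boundary scan in NgraphOpIntervals.
    void ScanBounds(const std::vector<std::string>& ops, int* left, int* right) {
      const int size = static_cast<int>(ops.size());
      int l = 0;
      // Stop at the first feed OR fetch op (the fix), not only feed,
      // so a block without feeds cannot scan past its fetch section.
      while (l < size && ops[l] != "feed" && ops[l] != "fetch") ++l;
      while (l < size && ops[l] == "feed") ++l;  // skip the feed section
      int r = l;
      while (r < size && ops[r] != "fetch") ++r;  // find the fetch section
      // New edge case: no feed section at all -> start the interval at 0.
      if (l == size || ops[l] == "fetch") l = 0;
      *left = l;   // first candidate op
      *right = r;  // one past the last candidate op
    }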
@@ -234,6 +232,7 @@ NgraphEngine::NgraphEngine(const framework::Scope& scope,
 }
 
 void NgraphEngine::Prepare(const std::vector<int>& interval) {
+  bool has_fetch = false, is_full = false;
   for (auto& var : p_bdesc->AllVars()) {
     if (!(var->GetType() == framework::proto::VarType::SELECTED_ROWS ||
           var->GetType() == framework::proto::VarType::LOD_TENSOR ||
@@ -264,6 +263,9 @@ void NgraphEngine::Prepare(const std::vector<int>& interval) {
   std::vector<paddle::framework::OpDesc*> ops_desc;
   for (auto op_desc : p_bdesc->AllOps()) {
     ops_desc.emplace_back(op_desc);
+    if (op_desc->Type() == framework::kFetchOpType) {
+      has_fetch = true;
+    }
   }
 
   for (auto op_desc : ops_desc) {
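
Recording has_fetch while the op descriptors are copied keeps it to a single pass; for reference, the standalone equivalent of that check looks like this (HasFetch and the plain-string op types are illustrative, not Paddle code):

    #include <algorithm>
    #include <string>
    #include <vector>

    // One-shot scan for a fetch op over a list of op type names.
    bool HasFetch(const std::vector<std::string>& op_types) {
      return std::any_of(op_types.begin(), op_types.end(),
                         [](const std::string& t) { return t == "fetch"; });
    }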
@@ -276,11 +278,11 @@ void NgraphEngine::Prepare(const std::vector<int>& interval) {
   if (interval[0] > 0 &&
       ops_desc.at(interval[0] - 1)->Type() == framework::kFeedOpType &&
       interval[1] < static_cast<int>(ops_desc.size()) &&
-      ops_desc.at(interval.at(1))->Type() == framework::kFetchOpType) {
-    this->op_state_ = OpState::FULL;
+      ops_desc.at(interval[1])->Type() == framework::kFetchOpType) {
+    is_full = true;
   }
 
-  if (this->op_state_ == OpState::FULL) {
+  if (is_full) {
     this->op_state_ = this->is_test_ ? OpState::FULL_TEST : OpState::FULL_TRAIN;
   } else {
     this->op_state_ =
@@ -293,7 +295,8 @@ void NgraphEngine::Prepare(const std::vector<int>& interval) {
         framework::OpRegistry::CreateOp(*(ops_desc[idx])));
     ++idx;
   }
-  while (ops_desc.at(idx)->Type() != framework::kFetchOpType) {
+  while (idx < static_cast<int>(ops_desc.size()) &&
+         ops_desc.at(idx)->Type() != framework::kFetchOpType) {
     auto op_desc = ops_desc.at(idx);
     for (auto& var_name_item : op_desc->Inputs()) {
       for (auto& var_name : var_name_item.second) {
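
The new size conjunct is the actual bug fix here: std::vector::at throws std::out_of_range once idx runs off the end, which happens whenever no fetch op follows the interval. The guard pattern in isolation (SkipUntil is a hypothetical helper, not Paddle code):

    #include <string>
    #include <vector>

    // Advance idx until `stop` is found; returns types.size() if it never is.
    // The size check must come first: the short-circuiting && guarantees
    // types.at(idx) is only evaluated for in-range idx.
    int SkipUntil(const std::vector<std::string>& types, int idx,
                  const std::string& stop) {
      while (idx < static_cast<int>(types.size()) && types.at(idx) != stop) {
        ++idx;
      }
      return idx;
    }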
@@ -303,6 +306,10 @@ void NgraphEngine::Prepare(const std::vector<int>& interval) {
     ++idx;
   }
 
+  if (!has_fetch) {
+    op_state_ = OpState::UNKNOWN;
+  }
+
   BuildNgIO(ops_desc, interval);
 }
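
With the FULL enum value gone (see the header diff below), the final state is derived from the two local flags plus is_test_, and the !has_fetch override comes last so it wins. A condensed view of the resulting decision table (ResolveState is illustrative only; the real logic stays inline in Prepare):

    enum class OpState { FULL_TRAIN, PARTIAL_TRAIN, FULL_TEST, PARTIAL_TEST, UNKNOWN };

    OpState ResolveState(bool has_fetch, bool is_full, bool is_test) {
      if (!has_fetch) return OpState::UNKNOWN;  // no fetch op: debug fallback
      if (is_full) return is_test ? OpState::FULL_TEST : OpState::FULL_TRAIN;
      return is_test ? OpState::PARTIAL_TEST : OpState::PARTIAL_TRAIN;
    }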
@@ -378,6 +385,7 @@ void NgraphEngine::BuildNgIO(const std::vector<framework::OpDesc*>& ops_desc,
       }
     }
   }
   for (size_t i = 0; i < var_in_.size(); ++i) {
     auto var_name = var_in_[i];
     if (persistables_.find(var_name) == persistables_.end()) {

paddle/fluid/operators/ngraph/ngraph_engine.h

@@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#ifndef PADDLE_FLUID_OPERATORS_NGRAPH_NGRAPH_ENGINE_H_
-#define PADDLE_FLUID_OPERATORS_NGRAPH_NGRAPH_ENGINE_H_
+#pragma once
+
 #include <memory>
 #include <set>
 #include <string>
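
The guard-to-pragma swap is behavior-preserving on every compiler Paddle targets (GCC, Clang, and MSVC all honor it), and it removes the risk of a mistyped or colliding guard macro. For comparison:

    // Before: classic include guard, removed by this commit.
    #ifndef PADDLE_FLUID_OPERATORS_NGRAPH_NGRAPH_ENGINE_H_
    #define PADDLE_FLUID_OPERATORS_NGRAPH_NGRAPH_ENGINE_H_
    // ... header contents ...
    #endif  // PADDLE_FLUID_OPERATORS_NGRAPH_NGRAPH_ENGINE_H_

    // After: one line; uniqueness is per file, not per macro name.
    #pragma once
    // ... header contents ...

Note that #pragma once is not part of the C++ standard, but it is a de facto universal extension.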
@@ -35,7 +35,6 @@ enum class OpState { /* nGraph support state on ops */
   PARTIAL_TRAIN, /* Support partial ops for train */
   FULL_TEST,     /* Support full list of ops for test */
   PARTIAL_TEST,  /* Support partial list of ops for test */
-  FULL,          /* All ops supported from feed to fetch */
   UNKNOWN        /* Output all for debug purpose */
 };
@@ -119,4 +118,3 @@ class NgraphEngine {
 
 }  // namespace operators
 }  // namespace paddle
-#endif  // PADDLE_FLUID_OPERATORS_NGRAPH_NGRAPH_ENGINE_H_
