!11167 fix executor codex

From: @kisnwang
Reviewed-by: @zhoufeng54,@chujinjin
Signed-off-by: @chujinjin
pull/11167/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit b8a3a539bc

@ -232,9 +232,7 @@ void Executor::OnException() {
} }
{ {
std::lock_guard<std::mutex> lock(pending_task_mutex_); std::lock_guard<std::mutex> lock(pending_task_mutex_);
for (auto iter = pending_tasks_.begin(); iter != pending_tasks_.end(); ++iter) { std::copy(pending_tasks_.begin(), pending_tasks_.end(), std::back_inserter(new_done_tasks));
new_done_tasks.emplace_back(*iter);
}
pending_tasks_.clear(); pending_tasks_.clear();
} }
{ {
@ -249,7 +247,7 @@ void Executor::OnRunGraphFinished() {
for (auto &task : new_ready_tasks) { for (auto &task : new_ready_tasks) {
ready_tasks_.push(task); ready_tasks_.push(task);
} }
if (new_ready_tasks.size() > 0) { if (!new_ready_tasks.empty()) {
task_cond_var_.notify_all(); task_cond_var_.notify_all();
} }
reenter_cond_var_.notify_all(); reenter_cond_var_.notify_all();
@ -288,15 +286,9 @@ void Executor::RunTask(const std::shared_ptr<Task> &task, bool sync, bool long_r
std::unique_lock<std::mutex> lock(task_mutex_); std::unique_lock<std::mutex> lock(task_mutex_);
if (long_run) { if (long_run) {
mindspore::ScopedLongRunning long_running; mindspore::ScopedLongRunning long_running;
sync_cond_var_.wait(lock, [this] { sync_cond_var_.wait(lock, [this] { return sync_run_task_finished_; });
bool finished = sync_run_task_finished_;
return finished;
});
} else { } else {
sync_cond_var_.wait(lock, [this] { sync_cond_var_.wait(lock, [this] { return sync_run_task_finished_; });
bool finished = sync_run_task_finished_;
return finished;
});
} }
} }
ClearDoneTasks(); ClearDoneTasks();

@ -42,7 +42,6 @@ enum TaskType {
kCompileNodes, kCompileNodes,
kCompileGraph, kCompileGraph,
kBuildGraph, kBuildGraph,
kBuildOp,
kRunGraph, kRunGraph,
kRunOp, kRunOp,
kCreateCommGroup, kCreateCommGroup,
@ -117,7 +116,7 @@ class RunOpTask : public Task {
void Run() override; void Run() override;
OpRunInfo *op_run_info_{nullptr}; OpRunInfo *op_run_info_{nullptr};
GraphInfo graph_info_; GraphInfo graph_info_;
std::vector<tensor::TensorPtr> *input_tensors_; std::vector<tensor::TensorPtr> *input_tensors_{nullptr};
VectorRef outputs_; VectorRef outputs_;
std::vector<int64_t> tensors_mask_; std::vector<int64_t> tensors_mask_;
}; };
@ -173,12 +172,9 @@ class Executor {
private: private:
void RunTask(const std::shared_ptr<Task> &task, bool sync, bool long_run = false); void RunTask(const std::shared_ptr<Task> &task, bool sync, bool long_run = false);
void UpdateOutputTensors(VectorRef *outputs,
const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node);
std::vector<std::shared_ptr<RunGraphTask>> GetNewReadyTasks(); std::vector<std::shared_ptr<RunGraphTask>> GetNewReadyTasks();
bool IsTaskReady(const std::shared_ptr<RunGraphTask> &task); bool IsTaskReady(const std::shared_ptr<RunGraphTask> &task);
void WaitTaskGraphAvailable(const SessionPtr &session, const std::shared_ptr<RunGraphTask> &task); void WaitTaskGraphAvailable(const SessionPtr &session, const std::shared_ptr<RunGraphTask> &task);
void CheckException();
void OnWorkerExit(); void OnWorkerExit();
void OnRunGraphFinished(); void OnRunGraphFinished();
void OnException(); void OnException();
@ -197,7 +193,7 @@ class Executor {
std::list<std::shared_ptr<RunGraphTask>> pending_tasks_; std::list<std::shared_ptr<RunGraphTask>> pending_tasks_;
std::vector<std::shared_ptr<Task>> done_tasks_; std::vector<std::shared_ptr<Task>> done_tasks_;
std::shared_ptr<std::thread> worker_; std::shared_ptr<std::thread> worker_;
std::atomic_bool sync_run_task_finished_{false}; bool sync_run_task_finished_{false};
}; };
} // namespace session } // namespace session
} // namespace mindspore } // namespace mindspore

@ -98,7 +98,7 @@ void AddControlEdge(const FuncGraphPtr &graph, const AnfNodePtr &node,
auto depend_node = input_cnode->input(kControlDependBehindIndex); auto depend_node = input_cnode->input(kControlDependBehindIndex);
MS_EXCEPTION_IF_NULL(prior_node); MS_EXCEPTION_IF_NULL(prior_node);
MS_EXCEPTION_IF_NULL(depend_node); MS_EXCEPTION_IF_NULL(depend_node);
PrimitivePtr prim_ptr = GetValueNode<PrimitivePtr>(input_cnode->input(0)); auto prim_ptr = GetValueNode<PrimitivePtr>(input_cnode->input(0));
MS_EXCEPTION_IF_NULL(prim_ptr); MS_EXCEPTION_IF_NULL(prim_ptr);
ValuePtr mode_ptr = prim_ptr->GetAttr("depend_mode"); ValuePtr mode_ptr = prim_ptr->GetAttr("depend_mode");
int64_t depend_mode = 0; int64_t depend_mode = 0;
@ -214,7 +214,7 @@ std::vector<AnfNodePtr> SplitSort(const FuncGraphPtr &graph, const std::string &
std::map<AnfNodePtr, std::vector<AnfNodePtr>> control_edges; std::map<AnfNodePtr, std::vector<AnfNodePtr>> control_edges;
CalcNodeRefCount(graph, &nodes_ref, &control_edges); CalcNodeRefCount(graph, &nodes_ref, &control_edges);
std::string handle_target = default_target; std::string handle_target = default_target;
std::string next_target = ""; std::string next_target;
to_visit.push(graph->get_return()); to_visit.push(graph->get_return());
while (!to_visit.empty() || !next_to_visit.empty()) { while (!to_visit.empty() || !next_to_visit.empty()) {
if (to_visit.empty()) { if (to_visit.empty()) {
@ -590,10 +590,10 @@ struct SplitDynamicNodesHelper {
if (pre_nodes.size() < merge_node_threshold) { if (pre_nodes.size() < merge_node_threshold) {
AddSegment(pre_nodes, segments, node_to_segment); AddSegment(pre_nodes, segments, node_to_segment);
} else { } else {
if (pre_common_nodes.size() > 0) { if (!pre_common_nodes.empty()) {
AddSegment(pre_common_nodes, segments, node_to_segment); AddSegment(pre_common_nodes, segments, node_to_segment);
} }
if (pre_dynamic_nodes.size() > 0) { if (!pre_dynamic_nodes.empty()) {
AddSegment(pre_dynamic_nodes, segments, node_to_segment); AddSegment(pre_dynamic_nodes, segments, node_to_segment);
} }
} }
@ -656,7 +656,7 @@ void SplitDynamicNodeSegment(const std::vector<AnfNodePtr> &segment_nodes, std::
void NodesToSegments(const std::vector<AnfNodePtr> &segment_nodes, std::vector<GraphSegmentPtr> *segments, void NodesToSegments(const std::vector<AnfNodePtr> &segment_nodes, std::vector<GraphSegmentPtr> *segments,
std::map<AnfNodePtr, GraphSegmentPtr> *node_to_segment) { std::map<AnfNodePtr, GraphSegmentPtr> *node_to_segment) {
if (segment_nodes.size() == 0) { if (segment_nodes.empty()) {
return; return;
} }
auto segment_target = GetCNodeTarget(segment_nodes[0]); auto segment_target = GetCNodeTarget(segment_nodes[0]);
@ -702,7 +702,7 @@ bool GraphPartition::IsCut(const AnfNodePtr &node) {
if (!IsValueNode<Primitive>(fn)) { if (!IsValueNode<Primitive>(fn)) {
return true; return true;
} }
PrimitivePtr node_prim = GetValueNode<PrimitivePtr>(fn); auto node_prim = GetValueNode<PrimitivePtr>(fn);
for (auto &prim : cut_list_) { for (auto &prim : cut_list_) {
MS_EXCEPTION_IF_NULL(prim); MS_EXCEPTION_IF_NULL(prim);
if (prim->name() == node_prim->name()) { if (prim->name() == node_prim->name()) {

Loading…
Cancel
Save