You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
mindspore/mindspore/ccsrc/backend/session/executor.cc

391 lines
12 KiB

/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "backend/session/executor.h"
#include <algorithm>
#include <exception>
#include "backend/session/executor_manager.h"
#include "runtime/device/kernel_runtime_manager.h"
#include "utils/comm_manager.h"
#include "utils/scoped_long_running.h"
namespace mindspore {
namespace session {
namespace {
// Rebind each output tensor to the device address produced by the finished
// run, refreshing the host-side shape for dynamic-shape nodes. Recurses into
// nested VectorRefs so arbitrarily nested graph outputs are handled.
void UpdateOutputTensors(const VectorRef *outputs,
                         const std::map<tensor::TensorPtr, session::KernelWithIndex> &tensor_to_node) {
  MS_EXCEPTION_IF_NULL(outputs);
  // const ref: copying each BaseRef element per iteration is unnecessary.
  for (const auto &item : *outputs) {
    if (utils::isa<VectorRefPtr>(item)) {
      auto vector_ref = utils::cast<VectorRef>(item);
      UpdateOutputTensors(&vector_ref, tensor_to_node);
    } else if (utils::isa<tensor::TensorPtr>(item)) {
      auto tensor = utils::cast<tensor::TensorPtr>(item);
      MS_EXCEPTION_IF_NULL(tensor);
      auto iter = tensor_to_node.find(tensor);
      if (iter != tensor_to_node.end()) {
        const auto &node = iter->second.first;
        const auto &output_index = iter->second.second;
        auto address = AnfAlgo::GetMutableOutputAddr(node, output_index);
        tensor->set_device_address(address);
        if (AnfAlgo::IsDynamicShape(node)) {
          // Dynamic-shape outputs are only known after execution; propagate
          // the inferred shape back to the host tensor.
          auto updated_shape = AnfAlgo::GetOutputInferShape(node, output_index);
          ShapeVector int_shape;
          std::transform(updated_shape.begin(), updated_shape.end(), std::back_inserter(int_shape), SizeToInt);
          tensor->set_shape(int_shape);
        }
      }
      if (tensor->NeedSyncDeviceToHostImmediately()) {
        // Pull data to host right away, then detach the device address so the
        // tensor is treated as host-resident until the next dispatch.
        tensor->data_sync(false);
        tensor->set_device_address(nullptr);
        tensor->set_sync_status(kNeedSyncHostToDevice);
      }
    }
  }
}
void NotifyOutputTensors(const VectorRef *outputs) {
MS_EXCEPTION_IF_NULL(outputs);
for (auto item : *outputs) {
if (utils::isa<VectorRefPtr>(item)) {
auto vector_ref = utils::cast<VectorRef>(item);
NotifyOutputTensors(&vector_ref);
} else if (utils::isa<tensor::TensorPtr>(item)) {
auto tensor = utils::cast<tensor::TensorPtr>(item);
MS_EXCEPTION_IF_NULL(tensor);
tensor->SetNeedWait(false);
}
}
}
bool TensorInVector(const VectorRef *outputs) {
MS_EXCEPTION_IF_NULL(outputs);
for (auto item : *outputs) {
if (utils::isa<VectorRefPtr>(item)) {
auto vector_ref = utils::cast<VectorRef>(item);
if (TensorInVector(&vector_ref)) {
return true;
}
} else if (utils::isa<tensor::TensorPtr>(item)) {
return true;
}
}
return false;
}
} // namespace
// Compile the nodes of a graph segment; the resulting graph id is stored in
// graph_id_ for the caller (Executor::CompileGraph) to read back.
void CompileNodesTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  MS_EXCEPTION_IF_NULL(segment_);
  graph_id_ = session_->CompileGraphImpl(segment_->nodes_, output_nodes_);
}
// Compile a whole func graph; the resulting graph id is stored in graph_id_.
void CompileGraphTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  graph_id_ = session_->CompileGraphImpl(NOT_NULL(func_graph_));
}
// Build a previously compiled graph identified by graph_id_ via the session.
void BuildGraphTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  session_->BuildGraphImpl(graph_id_);
}
// Execute a compiled graph on the worker thread. Any exception thrown during
// execution is captured into MsException (rethrown later on the front-end
// thread via CheckException) so the unlock/notify steps below always run.
void RunGraphTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  try {
    MS_LOG(INFO) << "Start run graph " << graph_id_;
    auto graph = session_->GetGraph(graph_id_);
    MS_EXCEPTION_IF_NULL(graph);
    graph->ResetGraphRunningStatus();
    session_->RunGraphImpl(graph_id_, input_tensors_, &outputs_);
    graph->OnRunGraphFinished();
    UpdateOutputTensors(&outputs_, tensor_to_node_);
    MS_LOG(INFO) << "End run graph " << graph_id_;
  } catch (const std::exception &e) {
    MsException::GetInstance().SetException();
  }
  // Release input tensors locked for this run and wake consumers waiting on
  // the output tensors, even if execution failed above.
  for (auto &tensor : input_need_lock_tensors_) {
    tensor->SetNeedWait(false);
  }
  NotifyOutputTensors(&outputs_);
  // Let the executor promote pending tasks that were waiting on this graph.
  ExecutorManager::Instance().OnRunGraphFinished();
}
// Execute a single op through the session; results are stored in outputs_.
void RunOpTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  session_->RunOpImpl(*op_run_info_, graph_info_, input_tensors_, &outputs_, tensors_mask_);
}
// Execute a graph by running its ops individually; results go to outputs_.
void RunOpsInGraphTask::Run() {
  MS_EXCEPTION_IF_NULL(session_);
  session_->RunOpsInGraphImpl(graph_id_, input_tensors_, &outputs_);
}
// Create a communication group synchronously; success is recorded in result_.
void CreateCommGroupTask::Run() { result_ = CommManager::GetInstance().CreateGroupSync(group_name_, ranks_); }
// Destroy a communication group; success is recorded in result_.
void DestroyCommGroupTask::Run() { result_ = CommManager::GetInstance().DestroyGroup(group_name_); }
// Construct an executor bound to a device and spawn its worker thread.
Executor::Executor(const std::string &device_name, uint32_t device_id)
    : device_name_(device_name), device_id_(device_id) {
  // Start the worker last, once the executor's state is initialized.
  worker_ = std::make_shared<std::thread>(&Executor::WorkerLoop, this);
}
// Stop and join the worker thread before the executor is destroyed.
Executor::~Executor() { WorkerJoin(); }
// Ask the worker thread to exit and join it. Joining from the worker thread
// itself would deadlock, so that case (and an unjoinable thread) is skipped.
void Executor::WorkerJoin() {
  if (!worker_->joinable() || worker_->get_id() == std::this_thread::get_id()) {
    return;
  }
  {
    // Queue an exit task so WorkerLoop leaves its loop.
    std::unique_lock<std::mutex> lock(task_mutex_);
    ready_tasks_.push(std::make_shared<ExitTask>());
    task_cond_var_.notify_all();
  }
  worker_->join();
}
// Main loop of the worker thread: pop tasks from ready_tasks_ and run them
// until an ExitTask is received. Exceptions from a task are stashed in
// MsException so a front-end thread can rethrow them via CheckException.
void Executor::WorkerLoop() {
  while (true) {
    std::shared_ptr<Task> task;
    {
      std::unique_lock<std::mutex> lock(task_mutex_);
      // Block until at least one task is queued.
      task_cond_var_.wait(lock, [this] { return !ready_tasks_.empty(); });
      task = ready_tasks_.front();
      ready_tasks_.pop();
    }
    if (task->type_ == kExit) {
      OnWorkerExit();
      return;
    }
    try {
      task->Run();
    } catch (const std::exception &e) {
      MsException::GetInstance().SetException();
    }
    {
      // Record completion; also keeps the task alive until a waiter sees it.
      std::unique_lock<std::mutex> lock(task_mutex_);
      done_tasks_.emplace_back(task);
    }
    if (task->type_ != kRunGraph || task->sync_run_) {
      // Wake the thread blocked in SyncRunTask; async RunGraph tasks have no
      // synchronous waiter to notify.
      sync_cond_var_.notify_all();
    }
  }
}
// Collect pending RunGraph tasks whose dependencies are now satisfied,
// removing them from pending_tasks_. Called after a graph run finishes so
// tasks that were waiting on its outputs can be scheduled.
std::vector<std::shared_ptr<RunGraphTask>> Executor::GetNewReadyTasks() {
  std::vector<std::shared_ptr<RunGraphTask>> new_ready_tasks;
  std::unique_lock<std::mutex> lock(pending_task_mutex_);
  for (auto iter = pending_tasks_.begin(); iter != pending_tasks_.end();) {
    if (IsTaskReady(*iter)) {
      new_ready_tasks.emplace_back(*iter);
      // erase() returns the next valid iterator; this is correct for any
      // standard container, unlike the erase(iter++) idiom.
      iter = pending_tasks_.erase(iter);
    } else {
      ++iter;
    }
  }
  return new_ready_tasks;
}
void Executor::OnRunGraphFinished() {
auto new_ready_tasks = GetNewReadyTasks();
std::unique_lock<std::mutex> lock(task_mutex_);
for (auto &task : new_ready_tasks) {
ready_tasks_.push(task);
}
if (new_ready_tasks.size() > 0) {
task_cond_var_.notify_all();
}
reenter_cond_var_.notify_all();
}
bool Executor::IsTaskReady(const std::shared_ptr<RunGraphTask> &task) {
MS_EXCEPTION_IF_NULL(task);
for (auto &input : task->input_need_wait_tensors_) {
MS_EXCEPTION_IF_NULL(input);
if (input->NeedWait()) {
return false;
}
}
auto session = task->session_;
MS_EXCEPTION_IF_NULL(session);
auto graph = session->GetGraph(task->graph_id_);
if (graph != nullptr) {
return graph->IsPreGraphFinished();
}
return true;
}
// Queue a task for the worker thread and block until notified of completion,
// then rethrow any exception the task captured via MsException.
void Executor::SyncRunTask(const std::shared_ptr<Task> &task) {
  std::unique_lock<std::mutex> lock(task_mutex_);
  ready_tasks_.push(task);
  // Drop references to previously finished tasks.
  done_tasks_.clear();
  task_cond_var_.notify_all();
  // NOTE(review): this wait has no predicate, so a spurious wakeup (or a
  // notify for an unrelated task) could return before this task has actually
  // finished. A safe fix needs a per-task completion flag waited on here —
  // confirm against the Task declaration before changing.
  sync_cond_var_.wait(lock);
  MsException::GetInstance().CheckException();
}
// Compile a graph segment synchronously and return the id of the new graph.
GraphId Executor::CompileGraph(const SessionPtr &session, const GraphSegmentPtr &segment,
                               const AnfNodePtrList &outputs) {
  auto compile_task = std::make_shared<CompileNodesTask>();
  compile_task->session_ = session;
  compile_task->segment_ = segment;
  compile_task->output_nodes_ = outputs;
  SyncRunTask(compile_task);
  return compile_task->graph_id_;
}
// Compile a whole func graph synchronously and return the new graph's id.
GraphId Executor::CompileGraph(const SessionPtr &session, NotNull<FuncGraphPtr> func_graph) {
  auto compile_task = std::make_shared<CompileGraphTask>();
  compile_task->session_ = session;
  compile_task->func_graph_ = func_graph;
  SyncRunTask(compile_task);
  return compile_task->graph_id_;
}
// Build a previously compiled graph synchronously on the worker thread.
void Executor::BuildGraph(const SessionPtr &session, GraphId graphId) {
  auto build_task = std::make_shared<BuildGraphTask>();
  build_task->session_ = session;
  build_task->graph_id_ = graphId;
  SyncRunTask(build_task);
}
// Run a compiled graph synchronously. Output tensors are created before the
// task is queued so the caller's `outputs` structure is populated, and the
// calling thread blocks until the worker finishes the run.
void Executor::RunGraph(const SessionPtr &session, const GraphId &graph_id,
                        const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(session);
  MS_EXCEPTION_IF_NULL(outputs);
  auto task = std::make_shared<RunGraphTask>();
  task->session_ = session;
  task->graph_id_ = graph_id;
  task->input_tensors_ = inputs;
  session->CreateOutputTensors(graph_id, inputs, outputs, &task->tensor_to_node_);
  // The task keeps its own copy of the output structure to fill in.
  task->outputs_ = *outputs;
  task->sync_run_ = true;
  // NOTE(review): presumably ScopedLongRunning marks this thread as blocked
  // on long-running work for the duration of the wait — confirm its contract.
  mindspore::ScopedLongRunning long_running;
  SyncRunTask(task);
}
// Run a compiled graph, asynchronously when possible. Output tensors are
// created immediately, and execution is queued to the worker thread; the
// caller only blocks where a dependency forces it.
void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
                             const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(session);
  MS_EXCEPTION_IF_NULL(outputs);
  auto task = std::make_shared<RunGraphTask>();
  task->session_ = session;
  task->graph_id_ = graph_id;
  task->input_tensors_ = inputs;
  task->input_need_lock_tensors_ = session->GetInputNeedLockTensors(graph_id, inputs);
  for (auto &tensor : inputs) {
    if (tensor->NeedWait()) {
      if (tensor->IsGraphOutput()) {
        // Output of another graph: defer — the task becomes ready once that
        // graph finishes (checked by IsTaskReady).
        task->input_need_wait_tensors_.emplace_back(tensor);
      } else {
        // Otherwise block here until the tensor's data is available.
        mindspore::ScopedLongRunning long_running;
        tensor->Wait();
      }
    }
  }
  // Mark locked inputs as busy for this run; RunGraphTask::Run clears them.
  for (auto &tensor : task->input_need_lock_tensors_) {
    tensor->SetNeedWait(true);
  }
  session->CreateOutputTensors(graph_id, inputs, outputs, &task->tensor_to_node_);
  // maintain a copy of output vector
  task->outputs_ = *outputs;
  // Sync-run a graph with no output tensors (e.g. a dataset graph): there is
  // nothing for a consumer to wait on, so run it to completion here.
  if (!TensorInVector(outputs)) {
    task->sync_run_ = true;
    mindspore::ScopedLongRunning long_running;
    SyncRunTask(task);
    return;
  }
  // If this graph's post-graph from a previous iteration has not finished,
  // wait for it before queueing another run (worker signals reenter_cond_var_).
  auto graph = session->GetGraph(task->graph_id_);
  if (graph != nullptr) {
    if (!graph->IsPostGraphFinished()) {
      mindspore::ScopedLongRunning long_running;
      std::unique_lock<std::mutex> lock(reenter_mutex_);
      reenter_cond_var_.wait(lock, [graph] { return graph->IsPostGraphFinished(); });
    }
  }
  bool ready = IsTaskReady(task);
  if (!ready) {
    // Park the task; OnRunGraphFinished will promote it to the ready queue.
    std::unique_lock<std::mutex> lock(pending_task_mutex_);
    pending_tasks_.push_back(task);
    return;
  }
  std::unique_lock<std::mutex> lock(task_mutex_);
  ready_tasks_.push(task);
  done_tasks_.clear();
  task_cond_var_.notify_all();
}
// Run a single op synchronously: wait for any pending input tensors, execute
// the op task on the worker thread, and copy the results into `outputs`.
void Executor::RunOp(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
                     std::vector<tensor::TensorPtr> *input_tensors, VectorRef *outputs,
                     const std::vector<int64_t> &tensors_mask) {
  // Validate pointer arguments up front, consistent with RunGraph/RunGraphAsync.
  MS_EXCEPTION_IF_NULL(session);
  MS_EXCEPTION_IF_NULL(op_run_info);
  MS_EXCEPTION_IF_NULL(input_tensors);
  MS_EXCEPTION_IF_NULL(outputs);
  auto task = std::make_shared<RunOpTask>();
  task->session_ = session;
  task->op_run_info_ = op_run_info;
  task->graph_info_ = graph_info;
  task->input_tensors_ = input_tensors;
  task->tensors_mask_ = tensors_mask;
  for (auto &tensor : *input_tensors) {
    MS_EXCEPTION_IF_NULL(tensor);
    if (tensor->NeedWait()) {
      tensor->Wait();
    }
  }
  mindspore::ScopedLongRunning long_running;
  SyncRunTask(task);
  *outputs = task->outputs_;
}
// Run all ops of a compiled graph one by one, synchronously, and copy the
// results into `outputs`.
void Executor::RunOpsInGraph(const SessionPtr &session, const GraphId &graph_id,
                             const std::vector<tensor::TensorPtr> &inputs, VectorRef *outputs) {
  MS_EXCEPTION_IF_NULL(session);
  MS_EXCEPTION_IF_NULL(outputs);
  auto run_task = std::make_shared<RunOpsInGraphTask>();
  run_task->session_ = session;
  run_task->graph_id_ = graph_id;
  run_task->input_tensors_ = inputs;
  SyncRunTask(run_task);
  *outputs = run_task->outputs_;
}
// Create a communication group on the worker thread; returns true on success.
// NOTE(review): `ranks` is passed by value per the declaration elsewhere; the
// signature must stay unchanged to match the header.
bool Executor::CreateCommGroup(const std::string &group_name, std::vector<uint32_t> ranks) {
  auto task = std::make_shared<CreateCommGroupTask>();
  task->group_name_ = group_name;
  task->ranks_ = ranks;
  SyncRunTask(task);
  return task->result_;
}
// Destroy a communication group on the worker thread; returns true on success.
bool Executor::DestroyCommGroup(const std::string &group_name) {
  auto task = std::make_shared<DestroyCommGroupTask>();
  task->group_name_ = group_name;
  SyncRunTask(task);
  return task->result_;
}
// Worker-thread teardown hook, run when the ExitTask is processed.
// NOTE(review): only the Ascend runtime is released here — presumably it is
// tied to this worker thread; confirm other device runtimes need no cleanup.
void Executor::OnWorkerExit() {
  if (device_name_ == kAscendDevice) {
    device::KernelRuntimeManager::Instance().ReleaseKernelRuntime(kAscendDevice, device_id_);
  }
}
} // namespace session
} // namespace mindspore