|
|
|
@ -16,6 +16,7 @@
|
|
|
|
|
#include "backend/session/executor.h"
|
|
|
|
|
#include "runtime/device/kernel_runtime_manager.h"
|
|
|
|
|
#include "backend/session/executor_manager.h"
|
|
|
|
|
#include "utils/comm_manager.h"
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace session {
|
|
|
|
@ -45,32 +46,6 @@ void UpdateOutputTensors(VectorRef *outputs,
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Recursively converts a BaseRef into a form suitable for returning to Python.
//
// - A BaseRef holding a tensor::TensorPtr is returned unchanged.
// - A BaseRef holding a VectorRef is converted into a py::tuple of the same
//   length; each element is converted recursively, so nested VectorRefs become
//   nested tuples. The tuple is returned through BaseRef's implicit conversion.
// - Anything else is reported through MS_LOG(EXCEPTION).
BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref) {
  if (!utils::isa<VectorRef>(base_ref)) {
    if (utils::isa<tensor::TensorPtr>(base_ref)) {
      // Leaf case: a single tensor is passed through untouched.
      return base_ref;
    }
    MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
  }

  auto elements = utils::cast<VectorRef>(base_ref);
  py::tuple converted_tuple(elements.size());
  for (size_t idx = 0; idx < elements.size(); ++idx) {
    // The recursive call yields either a tensor or a PyObjectRef wrapping a tuple.
    auto converted = TransformBaseRefListToTuple(elements[idx]);
    if (utils::isa<tensor::TensorPtr>(converted)) {
      auto tensor = utils::cast<tensor::TensorPtr>(converted);
      MS_EXCEPTION_IF_NULL(tensor);
      converted_tuple[idx] = tensor;
      continue;
    }
    if (utils::isa<PyObjectRef>(converted)) {
      py::object wrapped = utils::cast<PyObjectRef>(converted).object_;
      py::tuple nested = py::cast<py::tuple>(wrapped);
      converted_tuple[idx] = nested;
      continue;
    }
    MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
  }
  // The py::tuple is converted to a py::object and stored in a PyObjectRef.
  return converted_tuple;
}
|
|
|
|
|
} // namespace
|
|
|
|
|
void CompileNodesTask::Run() {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(session_);
|
|
|
|
@ -104,6 +79,10 @@ void RunOpTask::Run() {
|
|
|
|
|
session_->RunOp(*op_run_info_, graph_info_, input_tensors_, &outputs_);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Asks the CommManager singleton to create the communication group named
// group_name_ over ranks_, and records the returned status in result_.
void CreateCommGroupTask::Run() {
  result_ = CommManager::GetInstance().CreateGroupSync(group_name_, ranks_);
}
|
|
|
|
|
|
|
|
|
|
// Asks the CommManager singleton to destroy the communication group named
// group_name_, and records the returned status in result_.
void DestroyCommGroupTask::Run() {
  result_ = CommManager::GetInstance().DestroyGroup(group_name_);
}
|
|
|
|
|
|
|
|
|
|
Executor::Executor(const std::string &device_name, uint32_t device_id) {
|
|
|
|
|
device_name_ = device_name;
|
|
|
|
|
device_id_ = device_id;
|
|
|
|
@ -141,22 +120,8 @@ void Executor::WorkerLoop() {
|
|
|
|
|
} catch (const std::exception &e) {
|
|
|
|
|
exception_ptr_ = std::current_exception();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto task_type = task->type_;
|
|
|
|
|
task = nullptr;
|
|
|
|
|
if (task_type == kCompileNodes) {
|
|
|
|
|
compile_cond_var_.notify_all();
|
|
|
|
|
} else if (task_type == kCompileGraph) {
|
|
|
|
|
compile_cond_var_.notify_all();
|
|
|
|
|
} else if (task_type == kBuildGraph) {
|
|
|
|
|
build_cond_var_.notify_all();
|
|
|
|
|
} else if (task_type == kRunGraph) {
|
|
|
|
|
run_cond_var_.notify_all();
|
|
|
|
|
} else if (task_type == kBuildOp) {
|
|
|
|
|
build_op_cond_var_.notify_all();
|
|
|
|
|
} else if (task_type == kRunOp) {
|
|
|
|
|
run_op_cond_var_.notify_all();
|
|
|
|
|
}
|
|
|
|
|
sync_cond_var_.notify_all();
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -206,7 +171,7 @@ GraphId Executor::CompileGraphAsync(const SessionPtr &session, const AnfNodePtrL
|
|
|
|
|
task->output_nodes_ = outputs;
|
|
|
|
|
ready_tasks_.push(task);
|
|
|
|
|
task_cond_var_.notify_all();
|
|
|
|
|
compile_cond_var_.wait(lock);
|
|
|
|
|
sync_cond_var_.wait(lock);
|
|
|
|
|
CheckException();
|
|
|
|
|
return task->graph_id_;
|
|
|
|
|
}
|
|
|
|
@ -219,7 +184,7 @@ GraphId Executor::CompileGraphAsync(const SessionPtr &session, NotNull<FuncGraph
|
|
|
|
|
task->func_graph_ = func_graph;
|
|
|
|
|
ready_tasks_.push(task);
|
|
|
|
|
task_cond_var_.notify_all();
|
|
|
|
|
compile_cond_var_.wait(lock);
|
|
|
|
|
sync_cond_var_.wait(lock);
|
|
|
|
|
CheckException();
|
|
|
|
|
return task->graph_id_;
|
|
|
|
|
}
|
|
|
|
@ -232,7 +197,7 @@ void Executor::BuildGraphAsync(const SessionPtr &session, GraphId graphId) {
|
|
|
|
|
task->graph_id_ = graphId;
|
|
|
|
|
ready_tasks_.push(task);
|
|
|
|
|
task_cond_var_.notify_all();
|
|
|
|
|
build_cond_var_.wait(lock);
|
|
|
|
|
sync_cond_var_.wait(lock);
|
|
|
|
|
CheckException();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -258,7 +223,7 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
|
|
|
|
|
ready_tasks_.push(task);
|
|
|
|
|
task_cond_var_.notify_all();
|
|
|
|
|
py::gil_scoped_release release;
|
|
|
|
|
run_cond_var_.wait(lock);
|
|
|
|
|
sync_cond_var_.wait(lock);
|
|
|
|
|
CheckException();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -274,12 +239,12 @@ void Executor::BuildOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, c
|
|
|
|
|
task->tensors_mask_ = tensors_mask;
|
|
|
|
|
ready_tasks_.push(task);
|
|
|
|
|
task_cond_var_.notify_all();
|
|
|
|
|
build_op_cond_var_.wait(lock);
|
|
|
|
|
sync_cond_var_.wait(lock);
|
|
|
|
|
CheckException();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
py::tuple Executor::RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
|
|
|
|
|
const std::vector<tensor::TensorPtr> &input_tensors) {
|
|
|
|
|
void Executor::RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
|
|
|
|
|
const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {
|
|
|
|
|
CheckException();
|
|
|
|
|
std::unique_lock<std::mutex> lock(task_mutex_);
|
|
|
|
|
auto task = std::make_shared<RunOpTask>();
|
|
|
|
@ -289,18 +254,30 @@ py::tuple Executor::RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info
|
|
|
|
|
task->input_tensors_ = input_tensors;
|
|
|
|
|
ready_tasks_.push(task);
|
|
|
|
|
task_cond_var_.notify_all();
|
|
|
|
|
run_op_cond_var_.wait(lock);
|
|
|
|
|
sync_cond_var_.wait(lock);
|
|
|
|
|
CheckException();
|
|
|
|
|
*outputs = task->outputs_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Trans output to tuple
|
|
|
|
|
auto output_tensors = TransformBaseRefListToTuple(task->outputs_);
|
|
|
|
|
if (!utils::isa<PyObjectRef>(output_tensors) ||
|
|
|
|
|
!py::isinstance<py::tuple>(utils::cast<PyObjectRef>(output_tensors).object_)) {
|
|
|
|
|
MS_EXCEPTION(NotSupportError) << "The output tensors should be a tuple !";
|
|
|
|
|
}
|
|
|
|
|
py::object tuple_obj = utils::cast<PyObjectRef>(output_tensors).object_;
|
|
|
|
|
py::tuple tuple_tensors = py::cast<py::tuple>(tuple_obj);
|
|
|
|
|
return tuple_tensors;
|
|
|
|
|
bool Executor::CreateCommGroup(const std::string &group_name, std::vector<uint32_t> ranks) {
|
|
|
|
|
std::unique_lock<std::mutex> lock(task_mutex_);
|
|
|
|
|
auto task = std::make_shared<CreateCommGroupTask>();
|
|
|
|
|
task->group_name_ = group_name;
|
|
|
|
|
task->ranks_ = ranks;
|
|
|
|
|
ready_tasks_.push(task);
|
|
|
|
|
task_cond_var_.notify_all();
|
|
|
|
|
sync_cond_var_.wait(lock);
|
|
|
|
|
return task->result_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool Executor::DestroyCommGroup(const std::string &group_name) {
|
|
|
|
|
std::unique_lock<std::mutex> lock(task_mutex_);
|
|
|
|
|
auto task = std::make_shared<DestroyCommGroupTask>();
|
|
|
|
|
task->group_name_ = group_name;
|
|
|
|
|
ready_tasks_.push(task);
|
|
|
|
|
task_cond_var_.notify_all();
|
|
|
|
|
sync_cond_var_.wait(lock);
|
|
|
|
|
return task->result_;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void Executor::StopWorker() {
|
|
|
|
|