Use the unified Execute function to run a Graph or a Single Op Graph.

pull/5832/head
Zhang Qinghua 5 years ago
parent 77dd91a646
commit c0070d3d49
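
The following is not part of the commit: a minimal, self-contained sketch of the dispatch behavior the hunks below introduce. The unified Execute resolves is_task_sink from the context only for whole-graph runs (is_task == true) and always uses false for single-op graphs. The stand-in types KernelGraph and KernelRuntime and the helper GetEnableTaskSinkFromContext are hypothetical simplifications of the MindSpore classes named in the diff, kept only so the sketch compiles on its own.

// Minimal sketch: stand-in types, not the MindSpore implementation.
#include <iostream>
#include <memory>

struct KernelGraph {};  // placeholder for session::KernelGraph

struct KernelRuntime {
  // Mirrors the new Run(graph, is_task_sink, ...) split: task-sink graphs run
  // as a pre-loaded task model, everything else launches kernels one by one.
  bool Run(const KernelGraph *graph, bool is_task_sink) {
    if (is_task_sink) {
      std::cout << "RunTask: execute the pre-loaded task-sink model\n";
    } else {
      std::cout << "LaunchKernel: launch kernels one by one\n";
    }
    return graph != nullptr;
  }
};

// Hypothetical stand-in for context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK).
bool GetEnableTaskSinkFromContext() { return true; }

// Unified entry point: RunGraph calls Execute(graph, true), RunOp calls Execute(graph, false).
bool Execute(const std::shared_ptr<KernelGraph> &kernel_graph, bool is_task) {
  bool is_task_sink = false;
  if (is_task) {
    is_task_sink = GetEnableTaskSinkFromContext();  // only whole-graph runs may sink tasks
  }
  KernelRuntime runtime;
  return runtime.Run(kernel_graph.get(), is_task_sink);
}

int main() {
  auto graph = std::make_shared<KernelGraph>();
  Execute(graph, true);   // RunGraph path: whole graph, may use task sink
  Execute(graph, false);  // RunOp path: single-op graph, never task sink
  return 0;
}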

@@ -318,7 +318,7 @@ void AscendSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::
#endif
{
// run task on device
Execute(kernel_graph);
Execute(kernel_graph, true);
}
// summary
Summary(kernel_graph.get());
@@ -348,17 +348,6 @@ void AscendSession::RunOpHardwareOptimize(const std::shared_ptr<session::KernelG
MS_LOG(INFO) << "Finish";
}
void AscendSession::RunOpExecTask(const std::shared_ptr<KernelGraph> &kernel_graph) const {
MS_LOG(INFO) << "Start!";
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
MS_EXCEPTION_IF_NULL(runtime_instance);
bool ret_ok = runtime_instance->LaunchKernel(kernel_graph.get());
if (!ret_ok) {
MS_LOG(EXCEPTION) << "Run task error!";
}
MS_LOG(INFO) << "Finish!";
}
bool AscendSession::GraphCacheExist(const GraphInfo &graph_info) const {
return run_op_graphs_.find(graph_info) != run_op_graphs_.end();
}
@@ -398,7 +387,7 @@ void AscendSession::RunOp(const OpRunInfo &op_run_info, const GraphInfo &graph_i
// load input data to device
LoadInputData(graph, input_tensors);
// run op
RunOpExecTask(graph);
Execute(graph, false);
// get output
if (op_run_info.value != nullptr) {
std::vector<tensor::TensorPtr> pre_output_tensors;
@@ -552,21 +541,30 @@ void AscendSession::RunOpMemoryClear(const KernelGraph *kernel_graph) const {
void AscendSession::Load(const std::shared_ptr<KernelGraph> &kernel_graph) const {
MS_LOG(INFO) << "Start!";
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
(void)device::KernelAdjust::GetInstance().StepLoadCtrlInputs(kernel_graph);
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
MS_EXCEPTION_IF_NULL(runtime_instance);
bool ret_ok = runtime_instance->Load(kernel_graph.get());
bool ret_ok = runtime_instance->Load(kernel_graph.get(), is_task_sink);
if (!ret_ok) {
MS_LOG(EXCEPTION) << "Load task error!";
}
MS_LOG(INFO) << "Finish!";
}
void AscendSession::Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const {
void AscendSession::Execute(const std::shared_ptr<KernelGraph> &kernel_graph, bool is_task) const {
MS_LOG(INFO) << "Start!";
bool is_task_sink = false;
if (is_task) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
}
auto runtime_instance = device::KernelRuntimeManager::Instance().GetKernelRuntime(kAscendDevice, device_id_);
MS_EXCEPTION_IF_NULL(runtime_instance);
bool ret_ok = runtime_instance->Run(kernel_graph.get());
bool ret_ok = runtime_instance->Run(kernel_graph.get(), is_task_sink);
if (!ret_ok) {
MS_LOG(EXCEPTION) << "run task error!";
}

@@ -13,8 +13,10 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_ASCEND_SESSION_H
#define MINDSPORE_CCSRC_BACKEND_SESSION_ASCEND_SESSION_H
#include <unordered_map>
#include <string>
#include <memory>
@@ -82,13 +84,12 @@ class AscendSession : public SessionBasic {
KernelGraph *kernel_graph) const;
void RunOpMemoryClear(const KernelGraph *kernel_graph) const;
void Load(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void Execute(const std::shared_ptr<KernelGraph> &kernel_graph, bool is_task) const;
void Dump(const std::shared_ptr<KernelGraph> &kernel_graph) const;
void DumpAllGraphs(const std::vector<KernelGraphPtr> &all_graphs);
void LoadTensor(const std::shared_ptr<KernelGraph> &kernel_graph) const;
// below functions are used for run op
void RunOpHardwareOptimize(const std::shared_ptr<session::KernelGraph> &kernel_graph) const;
void RunOpExecTask(const std::shared_ptr<KernelGraph> &kernel_graph) const;
static void BackendOptimization(const std::vector<KernelGraphPtr> &all_graphs);
static void LinkChildGraphs(NotNull<KernelGraphPtr> graph);

@@ -118,7 +118,7 @@ void CPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Ten
debugger_->PreExecute(kernel_graph);
}
#endif
bool ret = runtime_.Run(kernel_graph.get());
bool ret = runtime_.Run(kernel_graph.get(), false);
if (!ret) {
MS_LOG(EXCEPTION) << "Run graph failed";
}

@@ -191,9 +191,9 @@ void GPUSession::Execute(const std::shared_ptr<KernelGraph> &kernel_graph) const
auto runtime_instance = device::KernelRuntimeManager::Instance().GetSingleKernelRuntime(kGPUDevice, device_id_);
MS_EXCEPTION_IF_NULL(runtime_instance);
#ifdef ENABLE_DEBUGGER
if (!runtime_instance->Run(kernel_graph.get(), debugger_.get())) {
if (!runtime_instance->Run(kernel_graph.get(), false, debugger_.get())) {
#else
if (!runtime_instance->Run(kernel_graph.get())) {
if (!runtime_instance->Run(kernel_graph.get(), false)) {
#endif
MS_LOG(EXCEPTION) << "GPU execute graph failed!";
}

@@ -454,10 +454,7 @@ DeviceAddressPtr AscendKernelRuntime::CreateDeviceAddress(void *device_ptr, size
return std::make_shared<AscendDeviceAddress>(device_ptr, device_size, format, type_id);
}
bool AscendKernelRuntime::Load(session::KernelGraph *graph) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
bool AscendKernelRuntime::Load(session::KernelGraph *graph, bool is_task_sink) {
if (!is_task_sink) {
return true;
}
@@ -609,17 +606,14 @@ void AscendKernelRuntime::DebugTaskIdName(GraphId graph_id) {
}
}
bool AscendKernelRuntime::Run(session::KernelGraph *graph, Debugger *debugger) {
bool AscendKernelRuntime::Run(session::KernelGraph *graph, bool is_task_sink, Debugger *debugger) {
bool ret = false;
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
#if defined(_WIN32) || defined(_WIN64)
auto start_time = std::chrono::steady_clock::now();
#else
struct timeval start_time, end_time;
(void)gettimeofday(&start_time, nullptr);
#endif
bool is_task_sink = context_ptr->get_param<bool>(MS_CTX_ENABLE_TASK_SINK);
if (is_task_sink) {
ret = RunTask(graph);
} else {

@@ -44,8 +44,8 @@ class AscendKernelRuntime : public KernelRuntime {
bool GenTask(const session::KernelGraph *graph);
bool LoadTask(const session::KernelGraph *graph);
bool RunTask(const session::KernelGraph *graph);
bool Load(session::KernelGraph *graph) override;
bool Run(session::KernelGraph *graph, Debugger *debugger = nullptr) override;
bool Load(session::KernelGraph *graph, bool is_task_sink) override;
bool Run(session::KernelGraph *graph, bool is_task_sink, Debugger *debugger = nullptr) override;
void ClearGraphRuntimeResource(uint32_t graph_id, const std::vector<AnfNodePtr> &inputs,
const std::unordered_set<ValueNodePtr> &value_nodes,
const std::vector<CNodePtr> &execution_order) override;

@@ -287,7 +287,7 @@ void CPUKernelRuntime::DecreaseSummaryRefCount(const session::NamedSummaryOutput
resource_manager_.DecreaseSummaryRefCount(summary_outputs);
}
bool CPUKernelRuntime::Run(session::KernelGraph *kernel_graph, Debugger *debugger) {
bool CPUKernelRuntime::Run(session::KernelGraph *kernel_graph, bool is_task_sink, Debugger *debugger) {
MS_EXCEPTION_IF_NULL(kernel_graph);
resource_manager_.IncreaseAddressRefCount(kernel_graph);

@@ -36,7 +36,7 @@ class CPUKernelRuntime : public KernelRuntime {
~CPUKernelRuntime() override = default;
bool Init() override { return true; }
bool Run(session::KernelGraph *graph, Debugger *debugger = nullptr) override;
bool Run(session::KernelGraph *graph, bool is_task_sink, Debugger *debugger = nullptr) override;
void AssignKernelAddress(session::KernelGraph *kernel_graph);
void BindInputOutput(session::KernelGraph *kernel_graph, const std::vector<tensor::TensorPtr> &inputs,
VectorRef *outputs);

@@ -433,7 +433,7 @@ void GPUKernelRuntime::AssignMemory(session::KernelGraph *graph) {
}
}
bool GPUKernelRuntime::Run(session::KernelGraph *graph, Debugger *debugger) {
bool GPUKernelRuntime::Run(session::KernelGraph *graph, bool is_task_sink, Debugger *debugger) {
struct timeval start_time, end_time;
(void)gettimeofday(&start_time, nullptr);
bool ret = true;

@@ -42,7 +42,7 @@ class GPUKernelRuntime : public KernelRuntime {
const std::unordered_set<ValueNodePtr> &value_nodes,
const std::vector<CNodePtr> &execution_order) override;
void AssignMemory(session::KernelGraph *graph) override;
bool Run(session::KernelGraph *graph, Debugger *debugger = nullptr) override;
bool Run(session::KernelGraph *graph, bool is_task_sink, Debugger *debugger = nullptr) override;
#ifdef ENABLE_DUMP_E2E
bool DumpData(session::KernelGraph *graph, Debugger *debugger = nullptr) override;
#endif

@@ -40,7 +40,7 @@ KernelRuntime::~KernelRuntime() {
#endif
}
bool KernelRuntime::Load(session::KernelGraph *graph) { return true; }
bool KernelRuntime::Load(session::KernelGraph *graph, bool is_task_sink) { return true; }
bool KernelRuntime::DumpData(mindspore::session::KernelGraph *graph, Debugger *debugger) {
if (graph != nullptr) {

@@ -59,8 +59,8 @@ class KernelRuntime {
bool DumpDataEnabled();
bool DumpDataEnabledIteration();
virtual bool DumpData(session::KernelGraph *graph, Debugger *debugger = nullptr);
virtual bool Load(session::KernelGraph *graph);
virtual bool Run(session::KernelGraph *graph, Debugger *debugger = nullptr) = 0;
virtual bool Load(session::KernelGraph *graph, bool is_task_sink);
virtual bool Run(session::KernelGraph *graph, bool is_task_sink, Debugger *debugger = nullptr) = 0;
bool LaunchKernel(const session::KernelGraph *graph);
bool LaunchTaskBasedOnSingleKernel(kernel::KernelModPtr kernel_mod_ptr, const AddressPtrList &kernel_inputs,
const AddressPtrList &kernel_outputs,

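For reference, a minimal sketch of the resulting runtime interface shape (assumed, simplified names rather than the actual MindSpore headers): Load and Run now take is_task_sink explicitly, the base Load stays a no-op, and non-Ascend runtimes accept the flag but ignore it.

#include <iostream>

struct KernelGraph {};   // stand-in for session::KernelGraph
struct Debugger {};      // stand-in for the debugger handle

class KernelRuntime {
 public:
  virtual ~KernelRuntime() = default;
  // Default Load is a no-op, mirroring KernelRuntime::Load in the diff.
  virtual bool Load(KernelGraph * /*graph*/, bool /*is_task_sink*/) { return true; }
  virtual bool Run(KernelGraph *graph, bool is_task_sink, Debugger *debugger = nullptr) = 0;
};

// CPU/GPU-style runtime: accepts is_task_sink but always launches kernels directly.
class DirectLaunchRuntime : public KernelRuntime {
 public:
  bool Run(KernelGraph *graph, bool /*is_task_sink*/, Debugger * /*debugger*/) override {
    std::cout << "launch kernels directly\n";
    return graph != nullptr;
  }
};

int main() {
  KernelGraph graph;
  DirectLaunchRuntime runtime;
  runtime.Load(&graph, /*is_task_sink=*/false);
  runtime.Run(&graph, /*is_task_sink=*/false, nullptr);
  return 0;
}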