diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc index b6ccad75ae..fb8fb05502 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.cc @@ -34,6 +34,7 @@ #include "pipeline/jit/action.h" #include "utils/context/graph_kernel_flags.h" #include "vm/segment_runner.h" +#include "utils/ms_context.h" #if ENABLE_GPU #include "runtime/device/gpu/kernel_info_setter.h" #endif @@ -806,6 +807,19 @@ std::vector GetReduceAxis(const AnfNodePtr &node) { return axis; } +kernel::Processor GetProcessorFromContext() { + kernel::Processor processor = kernel::Processor::UNKNOWN; + auto context_ptr = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(context_ptr); + auto device_info = context_ptr->get_param(MS_CTX_DEVICE_TARGET); + if (device_info == kGPUDevice) { + processor = kernel::Processor::CUDA; + } else if (device_info == kAscendDevice) { + processor = kernel::Processor::AICORE; + } + return processor; +} + CNodePtr CreateCNode(const std::vector &inputs, const FuncGraphPtr &func_graph, const DataInfo &out_info) { // Limitation: 1. Node's attributes should be set out of this function; 2. only one output. MS_EXCEPTION_IF_NULL(out_info.type); @@ -862,7 +876,7 @@ CNodePtr CreateCNode(const std::vector &inputs, const FuncGraphPtr & info_builder.SetInputsDeviceType(input_types); info_builder.SetOutputsFormat(output_formats); info_builder.SetOutputsDeviceType(output_types); - info_builder.SetProcessor(kernel::Processor::CUDA); + info_builder.SetProcessor(GetProcessorFromContext()); info_builder.SetKernelType(KernelType::AKG_KERNEL); info_builder.SetFusionType(kernel::FusionType::OPAQUE); auto selected_info = info_builder.Build(); diff --git a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h index 9b7810e71e..19001037c7 100644 --- a/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h +++ b/mindspore/ccsrc/backend/optimizer/graph_kernel/graph_kernel_helper.h @@ -84,6 +84,7 @@ TypePtr GetType(const AnfNodePtr &node); ShapeVector GetShape(const AnfNodePtr &node); ShapeVector GetDeviceShape(const AnfNodePtr &node); std::vector GetReduceAxis(const AnfNodePtr &node); +kernel::Processor GetProcessorFromContext(); CNodePtr CreateCNode(const std::vector &inputs, const FuncGraphPtr &func_graph, const DataInfo &out_info); void SetNodeAttrSafely(const std::string &key, const ValuePtr &value, const AnfNodePtr &node);