!14237 [GraphKernel] infer processor from ms_context inside function CreateCNode

From: @looop5
Reviewed-by: @gaoxiong1,@dylangeng
Signed-off-by: @dylangeng
pull/14237/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit e644a66973

@ -34,6 +34,7 @@
#include "pipeline/jit/action.h" #include "pipeline/jit/action.h"
#include "utils/context/graph_kernel_flags.h" #include "utils/context/graph_kernel_flags.h"
#include "vm/segment_runner.h" #include "vm/segment_runner.h"
#include "utils/ms_context.h"
#if ENABLE_GPU #if ENABLE_GPU
#include "runtime/device/gpu/kernel_info_setter.h" #include "runtime/device/gpu/kernel_info_setter.h"
#endif #endif
@ -806,6 +807,19 @@ std::vector<int64_t> GetReduceAxis(const AnfNodePtr &node) {
return axis; return axis;
} }
kernel::Processor GetProcessorFromContext() {
kernel::Processor processor = kernel::Processor::UNKNOWN;
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
auto device_info = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
if (device_info == kGPUDevice) {
processor = kernel::Processor::CUDA;
} else if (device_info == kAscendDevice) {
processor = kernel::Processor::AICORE;
}
return processor;
}
CNodePtr CreateCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph, const DataInfo &out_info) { CNodePtr CreateCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph, const DataInfo &out_info) {
// Limitation: 1. Node's attributes should be set out of this function; 2. only one output. // Limitation: 1. Node's attributes should be set out of this function; 2. only one output.
MS_EXCEPTION_IF_NULL(out_info.type); MS_EXCEPTION_IF_NULL(out_info.type);
@ -862,7 +876,7 @@ CNodePtr CreateCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &
info_builder.SetInputsDeviceType(input_types); info_builder.SetInputsDeviceType(input_types);
info_builder.SetOutputsFormat(output_formats); info_builder.SetOutputsFormat(output_formats);
info_builder.SetOutputsDeviceType(output_types); info_builder.SetOutputsDeviceType(output_types);
info_builder.SetProcessor(kernel::Processor::CUDA); info_builder.SetProcessor(GetProcessorFromContext());
info_builder.SetKernelType(KernelType::AKG_KERNEL); info_builder.SetKernelType(KernelType::AKG_KERNEL);
info_builder.SetFusionType(kernel::FusionType::OPAQUE); info_builder.SetFusionType(kernel::FusionType::OPAQUE);
auto selected_info = info_builder.Build(); auto selected_info = info_builder.Build();

@ -84,6 +84,7 @@ TypePtr GetType(const AnfNodePtr &node);
ShapeVector GetShape(const AnfNodePtr &node); ShapeVector GetShape(const AnfNodePtr &node);
ShapeVector GetDeviceShape(const AnfNodePtr &node); ShapeVector GetDeviceShape(const AnfNodePtr &node);
std::vector<int64_t> GetReduceAxis(const AnfNodePtr &node); std::vector<int64_t> GetReduceAxis(const AnfNodePtr &node);
kernel::Processor GetProcessorFromContext();
CNodePtr CreateCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph, const DataInfo &out_info); CNodePtr CreateCNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph, const DataInfo &out_info);
void SetNodeAttrSafely(const std::string &key, const ValuePtr &value, const AnfNodePtr &node); void SetNodeAttrSafely(const std::string &key, const ValuePtr &value, const AnfNodePtr &node);

Loading…
Cancel
Save