diff --git a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel_factory.cc b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel_factory.cc index 249450c193..accd742976 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel_factory.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel_factory.cc @@ -38,7 +38,7 @@ void CPUKernelFactory::Register(const std::string &kernel_name, const KernelAttr } std::shared_ptr CPUKernelFactory::Create(const std::string &kernel_name, const CNodePtr &apply_kernel) { - auto kernel_info = apply_kernel->kernel_info(); + auto kernel_info = dynamic_cast(apply_kernel->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); const KernelBuildInfo *kernel_build_Info = kernel_info->select_kernel_build_info(); MS_EXCEPTION_IF_NULL(kernel_build_Info); diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.cc index 3820089e35..4a0191abd7 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.cc @@ -137,7 +137,7 @@ std::pair GpuKernelFactory::GpuKernelAttrCheck(const std::string & } GpuKernel *GpuKernelFactory::Create(const std::string &kernel_name, const CNodePtr &apply_kernel) { - auto kernel_info = apply_kernel->kernel_info(); + auto kernel_info = dynamic_cast(apply_kernel->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); const KernelBuildInfo *kernel_build_Info = kernel_info->select_kernel_build_info(); MS_EXCEPTION_IF_NULL(kernel_build_Info); diff --git a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/parameter_and_transop_fusion.cc b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/parameter_and_transop_fusion.cc index 9f44eb9d89..0c2667e4d9 100644 --- a/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/parameter_and_transop_fusion.cc +++ b/mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/parameter_and_transop_fusion.cc @@ 
-63,7 +63,7 @@ const AnfNodePtr ParamTransRoad(const FuncGraphPtr &func_graph, const AnfNodePtr kernel::KernelBuildInfoPtr GetKernelBuildInfo(const CNodePtr &cast, const string &format, TypeId input_type, TypeId output_type) { MS_EXCEPTION_IF_NULL(cast); - auto kernel_info = cast->kernel_info(); + auto kernel_info = dynamic_cast(cast->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); auto cast_build_info = kernel_info->select_kernel_build_info(); MS_EXCEPTION_IF_NULL(cast_build_info); diff --git a/mindspore/ccsrc/backend/optimizer/pass/common_subexpression_elimination.cc b/mindspore/ccsrc/backend/optimizer/pass/common_subexpression_elimination.cc index a485b196af..133a7e764a 100644 --- a/mindspore/ccsrc/backend/optimizer/pass/common_subexpression_elimination.cc +++ b/mindspore/ccsrc/backend/optimizer/pass/common_subexpression_elimination.cc @@ -23,8 +23,8 @@ namespace { bool CheckEqualKernelBuildInfo(const AnfNodePtr &main, const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(main); MS_EXCEPTION_IF_NULL(node); - auto main_kernel_info = main->kernel_info(); - auto node_kernel_info = node->kernel_info(); + auto main_kernel_info = dynamic_cast(main->kernel_info()); + auto node_kernel_info = dynamic_cast(node->kernel_info()); if (main_kernel_info == nullptr && node_kernel_info == nullptr) { return true; } diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc index 0e5af203bc..8ed290cc13 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc @@ -338,7 +338,7 @@ std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t if (!AnfAlgo::IsRealKernel(node)) { return AnfAlgo::GetPrevNodeOutputFormat(node, output_idx); } - auto kernel_info = node->kernel_info(); + auto kernel_info = dynamic_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); auto build_info = kernel_info->select_kernel_build_info(); 
MS_EXCEPTION_IF_NULL(build_info); @@ -360,7 +360,7 @@ std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t i if (!IsRealKernel(node)) { GetPrevNodeOutputFormat(node, input_idx); } - auto kernel_info = node->kernel_info(); + auto kernel_info = dynamic_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); auto build_info = kernel_info->select_kernel_build_info(); MS_EXCEPTION_IF_NULL(build_info); @@ -467,7 +467,7 @@ std::vector AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNode if (!IsRealKernel(node)) { return GetPrevNodeOutputReshapeType(node, input_idx); } - auto kernel_info = node->kernel_info(); + auto kernel_info = dynamic_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); auto build_info = kernel_info->select_kernel_build_info(); MS_EXCEPTION_IF_NULL(build_info); @@ -486,7 +486,7 @@ std::vector AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNod if (!IsRealKernel(node)) { return GetPrevNodeOutputReshapeType(node, output_idx); } - auto kernel_info = node->kernel_info(); + auto kernel_info = dynamic_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); auto build_info = kernel_info->select_kernel_build_info(); MS_EXCEPTION_IF_NULL(build_info); @@ -546,7 +546,7 @@ TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size if (!IsRealKernel(node)) { return GetPrevNodeOutputDeviceDataType(node, output_idx); } - auto kernel_info = node->kernel_info(); + auto kernel_info = dynamic_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); auto build_info = kernel_info->select_kernel_build_info(); MS_EXCEPTION_IF_NULL(build_info); @@ -567,7 +567,7 @@ TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_ if (!IsRealKernel(node)) { return GetPrevNodeOutputDeviceDataType(node, 0); } - auto kernel_info = node->kernel_info(); + auto kernel_info = dynamic_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); auto build_info = 
kernel_info->select_kernel_build_info(); MS_EXCEPTION_IF_NULL(build_info); @@ -597,7 +597,7 @@ const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node, MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node"; } } - auto kernel_info = node->kernel_info(); + auto kernel_info = dynamic_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); auto addr = kernel_info->GetOutputAddr(output_idx); if (addr == nullptr) { @@ -619,7 +619,7 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &nod MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node."; } } - auto kernel_info = node->kernel_info(); + auto kernel_info = dynamic_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); auto addr = kernel_info->GetMutableOutputAddr(output_idx); if (addr == nullptr) { @@ -636,7 +636,7 @@ bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_ MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ " << GetOutputTensorNum(node) << "#node:[ " << node->DebugString() << "]"; } - auto kernel_info = node->kernel_info(); + auto kernel_info = dynamic_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); return kernel_info->OutputAddrExist(output_idx); } @@ -656,7 +656,7 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(const AnfNode // set output device addr of anf_node void AnfRuntimeAlgorithm::SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) { MS_EXCEPTION_IF_NULL(node); - auto kernel_info = node->kernel_info(); + auto kernel_info = dynamic_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); if (!kernel_info->SetOutputAddr(addr, output_idx)) { MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set adr" << output_idx << " fail"; @@ -666,7 +666,7 @@ void AnfRuntimeAlgorithm::SetOutputAddr(const DeviceAddressPtr &addr, size_t out // set workspace device addr of anf_node void 
AnfRuntimeAlgorithm::SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) { MS_EXCEPTION_IF_NULL(node); - auto kernel_info = node->kernel_info(); + auto kernel_info = dynamic_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); if (!kernel_info->SetWorkspaceAddr(addr, output_idx)) { MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set adr" << output_idx << " fail"; @@ -676,7 +676,7 @@ void AnfRuntimeAlgorithm::SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t // get workspace device addr of anf_node DeviceAddress *AnfRuntimeAlgorithm::GetWorkspaceAddr(const AnfNodePtr &node, size_t output_idx) { MS_EXCEPTION_IF_NULL(node); - auto kernel_info = node->kernel_info(); + auto kernel_info = dynamic_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); auto addr = kernel_info->GetWorkspaceAddr(output_idx); if (addr == nullptr) { @@ -720,7 +720,7 @@ void AnfRuntimeAlgorithm::CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_ kernel::OpPattern AnfRuntimeAlgorithm::GetOpPattern(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); - auto kernel_info = node->kernel_info(); + auto kernel_info = dynamic_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); // select_kernel_build_info() has checked whether return pointer is null auto build_info = kernel_info->select_kernel_build_info(); @@ -731,7 +731,7 @@ kernel::OpPattern AnfRuntimeAlgorithm::GetOpPattern(const AnfNodePtr &node) { // get KernelBuildType of node, such as ATT,RT,FWK and so on KernelType AnfRuntimeAlgorithm::GetKernelType(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); - auto kernel_info = node->kernel_info(); + auto kernel_info = dynamic_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); // select_kernel_build_info() has checked whether return pointer is null auto build_info = kernel_info->select_kernel_build_info(); @@ -741,7 +741,7 @@ KernelType AnfRuntimeAlgorithm::GetKernelType(const AnfNodePtr &node) { kernel::Processor 
AnfRuntimeAlgorithm::GetProcessor(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); - auto kernel_info = node->kernel_info(); + auto kernel_info = dynamic_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); auto build_info = kernel_info->select_kernel_build_info(); MS_EXCEPTION_IF_NULL(build_info); @@ -750,7 +750,7 @@ kernel::Processor AnfRuntimeAlgorithm::GetProcessor(const AnfNodePtr &node) { kernel::FusionType AnfRuntimeAlgorithm::GetFusionType(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); - auto kernel_info = node->kernel_info(); + auto kernel_info = dynamic_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); auto build_info = kernel_info->select_kernel_build_info(); MS_EXCEPTION_IF_NULL(build_info); @@ -760,7 +760,7 @@ kernel::FusionType AnfRuntimeAlgorithm::GetFusionType(const AnfNodePtr &node) { // set select kernel_build_info void AnfRuntimeAlgorithm::SetSelectKernelBuildInfo(const KernelBuildInfoPtr &select_kernel_build_info, AnfNode *node) { MS_EXCEPTION_IF_NULL(node); - auto kernel_info = node->kernel_info(); + auto kernel_info = dynamic_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); return kernel_info->set_select_kernel_build_info(select_kernel_build_info); } @@ -768,7 +768,7 @@ void AnfRuntimeAlgorithm::SetSelectKernelBuildInfo(const KernelBuildInfoPtr &sel // get select kernel_build_info KernelBuildInfoPtr AnfRuntimeAlgorithm::GetSelectKernelBuildInfo(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); - auto kernel_info = node->kernel_info(); + auto kernel_info = dynamic_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); return kernel_info->GetMutableSelectKernelBuildInfo(); } @@ -776,7 +776,7 @@ KernelBuildInfoPtr AnfRuntimeAlgorithm::GetSelectKernelBuildInfo(const AnfNodePt // get kernelMode KernelMod *AnfRuntimeAlgorithm::GetKernelMod(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); - auto kernel_info = node->kernel_info(); + auto kernel_info = dynamic_cast(node->kernel_info()); 
MS_EXCEPTION_IF_NULL(kernel_info); return kernel_info->MutableKernelMod(); } @@ -784,7 +784,7 @@ KernelMod *AnfRuntimeAlgorithm::GetKernelMod(const AnfNodePtr &node) { // set kernel mod void AnfRuntimeAlgorithm::SetKernelMod(const KernelModPtr &kernel_mod, AnfNode *node) { MS_EXCEPTION_IF_NULL(node); - auto kernel_info = node->kernel_info(); + auto kernel_info = dynamic_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); kernel_info->set_kernel_mod(kernel_mod); } @@ -850,42 +850,42 @@ bool AnfRuntimeAlgorithm::IsParameterWeight(const ParameterPtr &node) { void AnfRuntimeAlgorithm::SetStreamId(uint32_t stream_id, AnfNode *node) { MS_EXCEPTION_IF_NULL(node); - auto kernel_info = node->kernel_info(); + auto kernel_info = dynamic_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); kernel_info->set_stream_id(stream_id); } uint32_t AnfRuntimeAlgorithm::GetStreamId(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); - auto kernel_info = node->kernel_info(); + auto kernel_info = dynamic_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); return kernel_info->stream_id(); } void AnfRuntimeAlgorithm::SetStreamDistinctionLabel(uint32_t stream_label, AnfNode *node) { MS_EXCEPTION_IF_NULL(node); - auto kernel_info = node->kernel_info(); + auto kernel_info = dynamic_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); kernel_info->set_stream_distinction_label(stream_label); } uint32_t AnfRuntimeAlgorithm::GetStreamDistinctionLabel(const AnfNode *node) { MS_EXCEPTION_IF_NULL(node); - auto kernel_info = node->kernel_info(); + auto kernel_info = dynamic_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); return kernel_info->stream_distinction_label(); } void AnfRuntimeAlgorithm::SetGraphId(uint32_t graph_id, AnfNode *node) { MS_EXCEPTION_IF_NULL(node); - auto kernel_info = node->kernel_info(); + auto kernel_info = dynamic_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); kernel_info->set_graph_id(graph_id); } uint32_t 
AnfRuntimeAlgorithm::GetGraphId(const AnfNode *node) { MS_EXCEPTION_IF_NULL(node); - auto kernel_info = node->kernel_info(); + auto kernel_info = dynamic_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); return kernel_info->graph_id(); } @@ -913,7 +913,7 @@ bool AnfRuntimeAlgorithm::IsFeatureMapOutput(const AnfNodePtr &node) { if (node->isa()) { return false; } - auto kernel_info = node->kernel_info(); + auto kernel_info = dynamic_cast(node->kernel_info()); MS_EXCEPTION_IF_NULL(kernel_info); return kernel_info->is_feature_map(); } diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h index 6bfc714d66..d5e8016a29 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h @@ -38,6 +38,8 @@ namespace mindspore { namespace session { using AnfVisitFuncion = std::function; using KernelWithIndex = std::pair; +using DeviceAddress = device::DeviceAddress; +using DeviceAddressPtr = device::DeviceAddressPtr; class AnfRuntimeAlgorithm { public: // get input_anf_node's real kernel by recurse diff --git a/mindspore/ccsrc/backend/session/gpu_session.cc b/mindspore/ccsrc/backend/session/gpu_session.cc index 1f109e0a6a..14e30c1a44 100644 --- a/mindspore/ccsrc/backend/session/gpu_session.cc +++ b/mindspore/ccsrc/backend/session/gpu_session.cc @@ -121,7 +121,7 @@ void GPUSession::LoadInputData(const std::shared_ptr &kernel_graph, if (input_node->isa() && AnfAlgo::OutputAddrExist(input_node, 0)) { auto pk_node = input_node->cast(); auto device_address = AnfAlgo::GetMutableOutputAddr(pk_node, 0); - auto tensor_address = tensor->device_address(); + auto tensor_address = std::dynamic_pointer_cast(tensor->device_address()); bool need_sync = false; if (ms_context->enable_pynative_infer()) { if (tensor_address == nullptr || tensor_address != device_address) { diff --git a/mindspore/ccsrc/backend/session/session_basic.cc 
b/mindspore/ccsrc/backend/session/session_basic.cc index a7960c4695..117e48fbb8 100644 --- a/mindspore/ccsrc/backend/session/session_basic.cc +++ b/mindspore/ccsrc/backend/session/session_basic.cc @@ -230,13 +230,14 @@ ParameterPtr ConstructRunOpParameter(const std::shared_ptr &graph, // set the kernel info of parameter auto kernel_build_info_builder = std::make_shared(); MS_EXCEPTION_IF_NULL(input_tensor); - if (input_tensor->device_address().get() == nullptr) { + auto device_address = std::dynamic_pointer_cast(input_tensor->device_address()); + if (device_address == nullptr) { kernel_build_info_builder->SetOutputsFormat(std::vector{kOpFormat_DEFAULT}); TypeId param_init_data_type = AnfAlgo::IsParameterWeight(param) ? kTypeUnknown : input_tensor->data_type(); kernel_build_info_builder->SetOutputsDeviceType(std::vector{param_init_data_type}); } else { - kernel_build_info_builder->SetOutputsFormat(std::vector{input_tensor->device_address()->format()}); - kernel_build_info_builder->SetOutputsDeviceType(std::vector{input_tensor->device_address()->type_id()}); + kernel_build_info_builder->SetOutputsFormat(std::vector{device_address->format()}); + kernel_build_info_builder->SetOutputsDeviceType(std::vector{device_address->type_id()}); } AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_builder->Build(), param.get()); // construct abstract of parameter @@ -319,7 +320,7 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const if (ref_real_node->isa() && node_graph->IsInternalOutput(ref_real_node) && node_graph->IsFinalOutputKernel(ref_real_node)) { auto kernel_info = ref_real_node->kernel_info(); - if (kernel_info == nullptr || kernel_info->select_kernel_build_info() == nullptr) { + if (kernel_info == nullptr || !kernel_info->has_build_info()) { MS_LOG(INFO) << "No kernel info"; return; } @@ -330,9 +331,9 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const } auto format = AnfAlgo::GetOutputFormat(ref_real_node, 
ref_real_node_index); auto type = AnfAlgo::GetOutputDeviceDataType(ref_real_node, ref_real_node_index); - parameter->set_kernel_info(std::make_shared()); - auto d_kernel_info = parameter->kernel_info(); + auto d_kernel_info = std::make_shared(); MS_EXCEPTION_IF_NULL(d_kernel_info); + parameter->set_kernel_info(d_kernel_info); kernel::KernelBuildInfo::KernelBuildInfoBuilder builder; builder.SetOutputsDeviceType({type}); builder.SetOutputsFormat({format}); diff --git a/mindspore/ccsrc/debug/anf_ir_dump.cc b/mindspore/ccsrc/debug/anf_ir_dump.cc index c7f2e2b14d..42d372cefb 100644 --- a/mindspore/ccsrc/debug/anf_ir_dump.cc +++ b/mindspore/ccsrc/debug/anf_ir_dump.cc @@ -128,7 +128,7 @@ void DumpKernelInfo(const CNodePtr &node, const std::shared_ptr return; } auto kernel_info = node->kernel_info(); - if (kernel_info == nullptr || kernel_info->select_kernel_build_info() == nullptr) { + if (kernel_info == nullptr || !kernel_info->has_build_info()) { return; } @@ -179,7 +179,7 @@ void DumpParams(const FuncGraphPtr &graph, std::ostringstream &buffer, OrderedMa // print parameters' type and shape PrintNodeOutputType(buffer, p); auto kernel_info = p->kernel_info(); - if (kernel_info != nullptr && kernel_info->select_kernel_build_info() != nullptr) { + if (kernel_info != nullptr && kernel_info->has_build_info()) { buffer << " : "; auto type = AnfAlgo::GetOutputDeviceDataType(p, 0); auto format = AnfAlgo::GetOutputFormat(p, 0); diff --git a/mindspore/ccsrc/runtime/device/ascend/kernel_select_graph_kernel.cc b/mindspore/ccsrc/runtime/device/ascend/kernel_select_graph_kernel.cc index 42e856d112..c76f96728f 100644 --- a/mindspore/ccsrc/runtime/device/ascend/kernel_select_graph_kernel.cc +++ b/mindspore/ccsrc/runtime/device/ascend/kernel_select_graph_kernel.cc @@ -362,8 +362,7 @@ void CheckFormatsAndDtypes(const CNodePtr &kernel_node, const std::vectorsecond) { - if (node_user.first->kernel_info() == nullptr || - node_user.first->kernel_info()->select_kernel_build_info() == nullptr) 
{ + if (node_user.first->kernel_info() == nullptr || !node_user.first->kernel_info()->has_build_info()) { // maybe not a real kernel. continue; } diff --git a/mindspore/ccsrc/runtime/device/device_address.h b/mindspore/ccsrc/runtime/device/device_address.h index 879caf45fc..32f5fcced9 100644 --- a/mindspore/ccsrc/runtime/device/device_address.h +++ b/mindspore/ccsrc/runtime/device/device_address.h @@ -21,8 +21,7 @@ #include #include #include "ir/dtype.h" - -using std::string; +#include "ir/device_sync.h" namespace mindspore { namespace device { @@ -51,15 +50,12 @@ namespace device { enum class DeviceAddressStatus { kInDevice, kInHost, kInDeviceToHost, kInHostToDevice }; enum class DeviceAddressType { kUnknown, kAscend, kCPU, kGPU }; -class DeviceAddress { +class DeviceAddress : public mindspore::DeviceSync { public: explicit DeviceAddress(void *ptr, size_t size) : ptr_(ptr), size_(size) {} explicit DeviceAddress(void *ptr, size_t size, const string &format, TypeId type_id) : ptr_(ptr), size_(size), format_(format), type_id_(type_id) {} virtual ~DeviceAddress() { ptr_ = nullptr; } - virtual bool SyncDeviceToHost(const std::vector &shape, size_t size, TypeId type, void *host_ptr) const = 0; - virtual bool SyncHostToDevice(const std::vector &shape, size_t size, TypeId type, - const void *host_ptr) const = 0; const void *GetPtr() const { return ptr_; } size_t GetSize() const { return size_; } std::string format() const { return format_; } diff --git a/mindspore/ccsrc/runtime/device/kernel_info.h b/mindspore/ccsrc/runtime/device/kernel_info.h index b8ab985c86..baded9d9a3 100644 --- a/mindspore/ccsrc/runtime/device/kernel_info.h +++ b/mindspore/ccsrc/runtime/device/kernel_info.h @@ -19,6 +19,7 @@ #include #include +#include "ir/kernel_info_dev.h" #include "backend/kernel_compiler/kernel_build_info.h" #include "runtime/device/ascend/ascend_device_address.h" #include "backend/kernel_compiler/kernel.h" @@ -27,7 +28,7 @@ namespace mindspore { const uint32_t kInvalidGraphId = 
UINT32_MAX; const uint32_t kInvalidDistincLabel = UINT32_MAX; namespace device { -class KernelInfo { +class KernelInfo : public KernelInfoDevice { public: KernelInfo() { kernel_mod_ = nullptr; @@ -41,6 +42,7 @@ class KernelInfo { } virtual ~KernelInfo() = default; + bool has_build_info() const override { return select_kernel_build_info() != nullptr; } const kernel::KernelBuildInfo *select_kernel_build_info() const; kernel::KernelBuildInfoPtr GetMutableSelectKernelBuildInfo() const; void set_select_kernel_build_info(const kernel::KernelBuildInfoPtr &select_kernel_build_info) { diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc index 49fddcae45..d5fd00da5b 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc @@ -214,8 +214,10 @@ void KernelRuntime::RunOpAssignInputMemory(const std::vector auto output_size = AnfAlgo::GetOutputTensorNum(item); for (size_t index = 0; index < output_size; index++) { MS_EXCEPTION_IF_NULL(input_tensors[input_index]); - if (input_tensors[input_index]->device_address().get() != nullptr) { - AnfAlgo::SetOutputAddr(input_tensors[input_index]->device_address(), index, item.get()); + auto output_address = + std::dynamic_pointer_cast(input_tensors[input_index]->device_address()); + if (output_address != nullptr) { + AnfAlgo::SetOutputAddr(output_address, index, item.get()); continue; } TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index); diff --git a/mindspore/core/ir/anf.h b/mindspore/core/ir/anf.h index 9df4d71c40..c1a28d57f1 100644 --- a/mindspore/core/ir/anf.h +++ b/mindspore/core/ir/anf.h @@ -27,8 +27,9 @@ #include #include "base/base.h" -#include "debug/info.h" +#include "ir/kernel_info_dev.h" #include "ir/scope.h" +#include "debug/info.h" // A MindSpore ANF IR defined here. 
// with BNF followed: @@ -71,12 +72,6 @@ class BaseRef; class Var; using VarPtr = std::shared_ptr; -namespace device { -class KernelInfo; -} // namespace device -using KernelInfoDevice = device::KernelInfo; -using KernelInfoDevicePtr = std::shared_ptr; - class AnfVisitor; class ParamValue; diff --git a/mindspore/core/ir/device_sync.h b/mindspore/core/ir/device_sync.h new file mode 100644 index 0000000000..a6bbe92233 --- /dev/null +++ b/mindspore/core/ir/device_sync.h @@ -0,0 +1,38 @@ +/** + * Copyright 2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_IR_DEVICE_SYNC_H_ +#define MINDSPORE_CCSRC_IR_DEVICE_SYNC_H_ + +#include +#include +#include + +#include "ir/dtype/type.h" + +using std::string; + +namespace mindspore { +// Interface for data synchronization between device and host. 
+class DeviceSync { + public: + virtual bool SyncDeviceToHost(const std::vector &shape, size_t size, TypeId type, void *host_ptr) const = 0; + virtual bool SyncHostToDevice(const std::vector &shape, size_t size, TypeId type, + const void *host_ptr) const = 0; +}; +using DeviceSyncPtr = std::shared_ptr; +} // namespace mindspore +#endif // MINDSPORE_CCSRC_IR_DEVICE_SYNC_H_ diff --git a/mindspore/core/ir/kernel_info_dev.h b/mindspore/core/ir/kernel_info_dev.h new file mode 100644 index 0000000000..87c717bdcb --- /dev/null +++ b/mindspore/core/ir/kernel_info_dev.h @@ -0,0 +1,32 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CCSRC_IR_KERNEL_INFO_DEV_H_ +#define MINDSPORE_CCSRC_IR_KERNEL_INFO_DEV_H_ + +#include + +namespace mindspore { +// Interface for device kernel program information. +class KernelInfoDevice { + public: + // If kernel program was built and build info is set. 
+ virtual bool has_build_info() const = 0; +}; +using KernelInfoDevicePtr = std::shared_ptr; +} // namespace mindspore + +#endif // MINDSPORE_CCSRC_IR_KERNEL_INFO_DEV_H_ diff --git a/mindspore/core/ir/tensor.cc b/mindspore/core/ir/tensor.cc index 6c966b32e3..8275acbbc5 100644 --- a/mindspore/core/ir/tensor.cc +++ b/mindspore/core/ir/tensor.cc @@ -326,7 +326,7 @@ Tensor::Tensor(const Tensor &tensor) data_(tensor.data_), dirty_(tensor.dirty_), id_(tensor.id_), - device_address_(tensor.device_address_) {} + device_sync_(tensor.device_sync_) {} Tensor::Tensor(const Tensor &tensor, TypeId data_type) : MetaTensor(data_type, tensor.shape_), @@ -334,7 +334,7 @@ Tensor::Tensor(const Tensor &tensor, TypeId data_type) data_(MakeTensorData(data_type, tensor.shape_, tensor.data_->data(), tensor.data_type_)), dirty_(tensor.dirty_), id_(tensor.id_), - device_address_(tensor.device_address_) {} + device_sync_(tensor.device_sync_) {} Tensor::Tensor(TypeId data_type, const std::vector &shape, TensorDataPtr data) : MetaTensor(data_type, shape), data_(std::move(data)), id_(MakeId()) {} @@ -379,10 +379,10 @@ bool Tensor::ValueEqual(const Tensor &tensor) const { Tensor &Tensor::AssignValue(const Tensor &tensor) { if (this != &tensor) { MetaTensor::operator=(tensor); - dirty_ = tensor.is_dirty(); - device_address_ = tensor.device_address(); + dirty_ = tensor.dirty_; + device_sync_ = tensor.device_sync_; data_ = tensor.data_; - id_ = tensor.id(); + id_ = tensor.id_; } return *this; } @@ -425,8 +425,8 @@ std::string Tensor::ToStringRepr() const { } void Tensor::data_sync() const { - if (device_address_ != nullptr) { - if (!device_address_->SyncDeviceToHost(shape(), static_cast(data().nbytes()), data_type(), data_c())) { + if (device_sync_ != nullptr) { + if (!device_sync_->SyncDeviceToHost(shape(), static_cast(data().nbytes()), data_type(), data_c())) { MS_LOG(EXCEPTION) << "SyncDeviceToHost when asnumpy."; } } diff --git a/mindspore/core/ir/tensor.h b/mindspore/core/ir/tensor.h index 
f2ed2c1609..727fb0fdd8 100644 --- a/mindspore/core/ir/tensor.h +++ b/mindspore/core/ir/tensor.h @@ -23,15 +23,13 @@ #include #include "Eigen/Core" -#include "runtime/device/device_address.h" +#include "ir/device_sync.h" #include "ir/meta_tensor.h" #include "include/ms_tensor.h" #include "utils/log_adapter.h" using float16 = Eigen::half; -using mindspore::device::DeviceAddress; -using DeviceAddressPtr = std::shared_ptr; // brief mindspore namespace. // // mindspore namespace is the top level namespace of MindSpore project. @@ -222,8 +220,8 @@ class Tensor : public MetaTensor { bool is_dirty() const { return dirty_; } void set_dirty(const bool dirty) { dirty_ = dirty; } - DeviceAddressPtr device_address() const { return device_address_; } - void set_device_address(const DeviceAddressPtr &device_address) { device_address_ = device_address; } + DeviceSyncPtr device_address() const { return device_sync_; } + void set_device_address(const DeviceSyncPtr &device_sync) { device_sync_ = device_sync; } std::string id() const { return id_; } @@ -234,7 +232,7 @@ class Tensor : public MetaTensor { TensorDataPtr data_{nullptr}; bool dirty_{true}; std::string id_{""}; - DeviceAddressPtr device_address_{nullptr}; + DeviceSyncPtr device_sync_{nullptr}; }; using TensorPtr = std::shared_ptr; using TensorPtrList = std::vector>; diff --git a/mindspore/core/ir/tensor_py.cc b/mindspore/core/ir/tensor_py.cc index f5f83d0e07..ef78d2720e 100644 --- a/mindspore/core/ir/tensor_py.cc +++ b/mindspore/core/ir/tensor_py.cc @@ -22,7 +22,6 @@ #include #include -#include "runtime/device/device_address.h" #include "pybind_api/api_register.h" #include "pybind_api/export_flags.h" #include "abstract/abstract_value.h" diff --git a/mindspore/core/ir/tensor_py.h b/mindspore/core/ir/tensor_py.h index 18ee547071..f917584977 100644 --- a/mindspore/core/ir/tensor_py.h +++ b/mindspore/core/ir/tensor_py.h @@ -81,8 +81,6 @@ struct type_caster : public npy_scalar_caster { } // namespace detail } // namespace 
pybind11 -using mindspore::device::DeviceAddress; -using DeviceAddressPtr = std::shared_ptr; // brief mindspore namespace. // // mindspore namespace is the top level namespace of Mindsporeession project. diff --git a/tests/ut/cpp/session/anf_runtime_algorithm_test.cc b/tests/ut/cpp/session/anf_runtime_algorithm_test.cc index e81870fd4f..ac38e5427e 100644 --- a/tests/ut/cpp/session/anf_runtime_algorithm_test.cc +++ b/tests/ut/cpp/session/anf_runtime_algorithm_test.cc @@ -255,7 +255,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetOutputFormat) { AnfAlgo::SetOutputInferTypeAndShape({kNumberTypeFloat32, kNumberTypeFloat32}, {shape, shape}, add.get()); MS_EXCEPTION_IF_NULL(add); add->set_kernel_info(std::make_shared()); - auto d_kernel_info = add->kernel_info(); + auto d_kernel_info = dynamic_cast(add->kernel_info()); MS_EXCEPTION_IF_NULL(d_kernel_info); KernelBuildInfoBuilder builder; builder.SetOutputsDeviceType({kFloat32->type_id(), kFloat16->type_id()}); @@ -274,7 +274,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetInputFormat) { auto add = kernel_graph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(add); add->set_kernel_info(std::make_shared()); - auto d_kernel_info = add->kernel_info(); + auto d_kernel_info = dynamic_cast(add->kernel_info()); MS_EXCEPTION_IF_NULL(d_kernel_info); KernelBuildInfoBuilder builder; builder.SetInputsDeviceType({kFloat32->type_id(), kFloat16->type_id()}); @@ -293,7 +293,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetPrevNodeOutputFormat) { auto pre_add = kernel_graph->NewCNode(pre_node_inputs); MS_EXCEPTION_IF_NULL(pre_add); pre_add->set_kernel_info(std::make_shared()); - auto d_kernel_info = pre_add->kernel_info(); + auto d_kernel_info = dynamic_cast(pre_add->kernel_info()); MS_EXCEPTION_IF_NULL(d_kernel_info); KernelBuildInfoBuilder builder; builder.SetOutputsDeviceType({kFloat32->type_id()}); @@ -373,7 +373,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetOutputDeviceShape) { MS_EXCEPTION_IF_NULL(add); add->set_abstract(tuple_abstract); 
add->set_kernel_info(std::make_shared()); - auto d_kernel_info = add->kernel_info(); + auto d_kernel_info = dynamic_cast(add->kernel_info()); MS_EXCEPTION_IF_NULL(d_kernel_info); KernelBuildInfoBuilder builder; builder.SetOutputsFormat({kOpFormat_NCHW, kOpFormat_NCHW, kOpFormat_NHWC, kOpFormat_FRAC_NZ}); @@ -404,7 +404,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetInputDeviceShape) { auto add = kernel_graph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(add); add->set_kernel_info(std::make_shared()); - auto d_kernel_info = add->kernel_info(); + auto d_kernel_info = dynamic_cast(add->kernel_info()); MS_EXCEPTION_IF_NULL(d_kernel_info); KernelBuildInfoBuilder builder; builder.SetInputsFormat({kOpFormat_NCHW, kOpFormat_NCHW, kOpFormat_NHWC}); @@ -457,7 +457,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetOutputDeviceDataTypeTest) { auto add = kernel_graph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(add); add->set_kernel_info(std::make_shared()); - auto d_kernel_info = add->kernel_info(); + auto d_kernel_info = dynamic_cast(add->kernel_info()); MS_EXCEPTION_IF_NULL(d_kernel_info); KernelBuildInfoBuilder builder; builder.SetOutputsDeviceType({kFloat32->type_id()}); @@ -474,7 +474,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetInputDeviceDataTypeTest) { auto add = kernel_graph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(add); add->set_kernel_info(std::make_shared()); - auto d_kernel_info = add->kernel_info(); + auto d_kernel_info = dynamic_cast(add->kernel_info()); MS_EXCEPTION_IF_NULL(d_kernel_info); KernelBuildInfoBuilder builder; builder.SetInputsDeviceType({kFloat32->type_id(), kFloat16->type_id()}); @@ -492,7 +492,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetPrevNodeOutputDeviceDataType) { auto pre_add = kernel_graph->NewCNode(pre_add_inputs); MS_EXCEPTION_IF_NULL(pre_add); pre_add->set_kernel_info(std::make_shared()); - auto d_kernel_info = pre_add->kernel_info(); + auto d_kernel_info = dynamic_cast(pre_add->kernel_info()); MS_EXCEPTION_IF_NULL(d_kernel_info); KernelBuildInfoBuilder builder; 
builder.SetOutputsDeviceType({kFloat32->type_id()}); @@ -513,7 +513,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetOutputAddr) { auto add = kernel_graph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(add); add->set_kernel_info(std::make_shared()); - auto d_kernel_info = add->kernel_info(); + auto d_kernel_info = dynamic_cast(add->kernel_info()); MS_EXCEPTION_IF_NULL(d_kernel_info); int *addr = nullptr; auto device_address = std::make_shared(addr, 1); @@ -528,7 +528,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetPrevNodeOutputAddr) { auto pre_add = kernel_graph->NewCNode(pre_add_inputs); MS_EXCEPTION_IF_NULL(pre_add); pre_add->set_kernel_info(std::make_shared()); - auto d_kernel_info = pre_add->kernel_info(); + auto d_kernel_info = dynamic_cast(pre_add->kernel_info()); MS_EXCEPTION_IF_NULL(d_kernel_info); int *addr = nullptr; auto device_address = std::make_shared(addr, 1); @@ -561,7 +561,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetWorkspaceAddr) { auto add = kernel_graph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(add); add->set_kernel_info(std::make_shared()); - auto d_kernel_info = add->kernel_info(); + auto d_kernel_info = dynamic_cast(add->kernel_info()); MS_EXCEPTION_IF_NULL(d_kernel_info); int *addr = nullptr; auto device_address = std::make_shared(addr, 1); @@ -643,7 +643,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetKernelType) { auto add = kernel_graph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(add); add->set_kernel_info(std::make_shared()); - auto d_kernel_info = add->kernel_info(); + auto d_kernel_info = dynamic_cast(add->kernel_info()); MS_EXCEPTION_IF_NULL(d_kernel_info); KernelBuildInfoBuilder builder; builder.SetKernelType(AKG_KERNEL); @@ -659,7 +659,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetProcessor) { auto add = kernel_graph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(add); add->set_kernel_info(std::make_shared()); - auto d_kernel_info = add->kernel_info(); + auto d_kernel_info = dynamic_cast(add->kernel_info()); MS_EXCEPTION_IF_NULL(d_kernel_info); KernelBuildInfoBuilder builder; 
builder.SetProcessor(kernel::AICORE); @@ -675,7 +675,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetFusionType) { auto add = kernel_graph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(add); add->set_kernel_info(std::make_shared()); - auto d_kernel_info = add->kernel_info(); + auto d_kernel_info = dynamic_cast(add->kernel_info()); MS_EXCEPTION_IF_NULL(d_kernel_info); KernelBuildInfoBuilder builder; builder.SetFusionType(kernel::CONVLUTION); @@ -703,7 +703,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetKernelMod) { auto add = kernel_graph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(add); add->set_kernel_info(std::make_shared()); - auto d_kernel_info = add->kernel_info(); + auto d_kernel_info = dynamic_cast(add->kernel_info()); MS_EXCEPTION_IF_NULL(d_kernel_info); d_kernel_info->set_kernel_mod(nullptr); EXPECT_EQ(AnfAlgo::GetKernelMod(add), nullptr); @@ -779,7 +779,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetStreamId) { auto add = kernel_graph->NewCNode(inputs); MS_EXCEPTION_IF_NULL(add); add->set_kernel_info(std::make_shared()); - auto d_kernel_info = add->kernel_info(); + auto d_kernel_info = dynamic_cast(add->kernel_info()); MS_EXCEPTION_IF_NULL(d_kernel_info); d_kernel_info->set_stream_id(0); EXPECT_EQ(AnfAlgo::GetStreamId(add), 0); diff --git a/tests/ut/cpp/session/kernel_graph_test.cc b/tests/ut/cpp/session/kernel_graph_test.cc index fb78a150b6..f24036b4aa 100644 --- a/tests/ut/cpp/session/kernel_graph_test.cc +++ b/tests/ut/cpp/session/kernel_graph_test.cc @@ -42,7 +42,7 @@ TEST_F(KernelGraphTest, NewValueNode) { auto x_abstract = std::make_shared(kFloat32, shape); add_value->set_abstract(x_abstract); add_value->set_kernel_info(std::make_shared()); - auto mutable_kernel_info = add_value->kernel_info(); + auto mutable_kernel_info = dynamic_cast(add_value->kernel_info()); MS_EXCEPTION_IF_NULL(mutable_kernel_info); std::shared_ptr builder = std::make_shared(); builder->SetOutputsFormat({kOpFormat_FRAC_Z});