diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc index e3e97a1b63..da9a522945 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.cc @@ -321,15 +321,6 @@ bool AscendKernelRuntime::LoadData(mindspore::session::KernelGraph *graph) { return true; } -bool AscendKernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) { - if (AnfAlgo::OutputAddrExist(kernel, index)) { - auto address = AnfAlgo::GetOutputAddr(kernel, index); - MS_EXCEPTION_IF_NULL(address); - return address->DeviceType() == DeviceAddressType::kAscend; - } - return false; -} - bool AscendKernelRuntime::KernelMemNotReuse(const AnfNodePtr &node) { bool need_dump = false; auto &dump_json_parser = DumpJsonParser::GetInstance(); diff --git a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h index b37c74de46..96ddd30c16 100644 --- a/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/ascend/ascend_kernel_runtime.h @@ -57,13 +57,13 @@ class AscendKernelRuntime : public KernelRuntime { void *context() const override { return rt_context_; } void PreInit() override; uint64_t GetAvailableMemMaxSize() const override; + DeviceAddressType GetTargetDeviceAddressType() const override { return DeviceAddressType::kAscend; }; void *compute_stream() const override { return stream_; } void *communication_stream() const override { return communication_stream_; } protected: DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, TypeId type_id) override; - bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index) override; bool KernelMemNotReuse(const AnfNodePtr &node) override; void KernelLaunchProfiling(const std::string &kernel_name) override; diff --git a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.h b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.h index 6eebb13809..0c40286536 100644 --- a/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/cpu/cpu_kernel_runtime.h @@ -46,6 +46,7 @@ class CPUKernelRuntime : public KernelRuntime { void DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs); bool GenDynamicKernel(const session::KernelGraph *graph) override { return true; } bool RunDynamicKernelAsync(const session::KernelGraph *graph) override { return true; } + DeviceAddressType GetTargetDeviceAddressType() const override { return DeviceAddressType::kCPU; }; protected: bool SyncStream() override { return true; }; diff --git a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h index b57021effa..a7f175d254 100644 --- a/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/gpu/gpu_kernel_runtime.h @@ -47,6 +47,7 @@ class GPUKernelRuntime : public KernelRuntime { bool Run(session::KernelGraph *graph, bool is_task_sink) override; bool GenDynamicKernel(const session::KernelGraph *graph) override { return true; } bool RunDynamicKernelAsync(const session::KernelGraph *graph) override { return true; } + DeviceAddressType GetTargetDeviceAddressType() const override { return DeviceAddressType::kGPU; }; void *compute_stream() const override { return stream_; } void *communication_stream() const override { return communication_stream_; } diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc index 96765ccc77..169f93490a 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc @@ -49,7 +49,9 @@ bool KernelRuntime::LoadData(session::KernelGraph *graph) { return false; } bool KernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) { MS_EXCEPTION_IF_NULL(kernel); if (AnfAlgo::OutputAddrExist(kernel, index)) { - return true; + const auto &address = AnfAlgo::GetOutputAddr(kernel, index); + MS_EXCEPTION_IF_NULL(address); + return address->DeviceType() == GetTargetDeviceAddressType(); } return false; } @@ -173,7 +175,7 @@ void KernelRuntime::RunOpAssignInputMemory(const std::vector MS_EXCEPTION_IF_NULL(input_tensors[input_index]); auto output_address = std::dynamic_pointer_cast(input_tensors[input_index]->device_address()); - if (output_address != nullptr) { + if (output_address != nullptr && output_address->DeviceType() == GetTargetDeviceAddressType()) { AnfAlgo::SetOutputAddr(output_address, index, item.get()); continue; } @@ -637,7 +639,8 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const MS_LOG(WARNING) << "Tensor is null"; return; } - if (tensor->device_address() != nullptr) { + auto output_address = std::dynamic_pointer_cast(tensor->device_address()); + if (output_address != nullptr && output_address->DeviceType() == GetTargetDeviceAddressType()) { AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast(tensor->device_address()), output_idx++, value_node.get()); continue; diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.h b/mindspore/ccsrc/runtime/device/kernel_runtime.h index 124d2942f6..f4f73e3f15 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.h @@ -108,6 +108,7 @@ class KernelRuntime { virtual uint64_t GetAvailableMemMaxSize() const { return 0; } void AddBufferPtr(std::shared_ptr ptr) { buffer_ptrs_.push_back(ptr); } void FreeAndClearBufferPtrs() { buffer_ptrs_.clear(); } + virtual DeviceAddressType GetTargetDeviceAddressType() const = 0; virtual void *compute_stream() const { return nullptr; } virtual void *communication_stream() const { return nullptr; }