Support ms_function + heterogenous

pull/13049/head
tanghuikang 4 years ago
parent 0c3152fe6b
commit dac64f30ee

@ -321,15 +321,6 @@ bool AscendKernelRuntime::LoadData(mindspore::session::KernelGraph *graph) {
return true;
}
bool AscendKernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) {
if (AnfAlgo::OutputAddrExist(kernel, index)) {
auto address = AnfAlgo::GetOutputAddr(kernel, index);
MS_EXCEPTION_IF_NULL(address);
return address->DeviceType() == DeviceAddressType::kAscend;
}
return false;
}
bool AscendKernelRuntime::KernelMemNotReuse(const AnfNodePtr &node) {
bool need_dump = false;
auto &dump_json_parser = DumpJsonParser::GetInstance();

@ -57,13 +57,13 @@ class AscendKernelRuntime : public KernelRuntime {
void *context() const override { return rt_context_; }
void PreInit() override;
uint64_t GetAvailableMemMaxSize() const override;
DeviceAddressType GetTargetDeviceAddressType() const override { return DeviceAddressType::kAscend; };
void *compute_stream() const override { return stream_; }
void *communication_stream() const override { return communication_stream_; }
protected:
DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
TypeId type_id) override;
bool NodeOutputDeviceAddressExist(const AnfNodePtr &node, size_t index) override;
bool KernelMemNotReuse(const AnfNodePtr &node) override;
void KernelLaunchProfiling(const std::string &kernel_name) override;

@ -46,6 +46,7 @@ class CPUKernelRuntime : public KernelRuntime {
void DecreaseSummaryRefCount(const session::NamedSummaryOutputs &summary_outputs);
bool GenDynamicKernel(const session::KernelGraph *graph) override { return true; }
bool RunDynamicKernelAsync(const session::KernelGraph *graph) override { return true; }
DeviceAddressType GetTargetDeviceAddressType() const override { return DeviceAddressType::kCPU; };
protected:
bool SyncStream() override { return true; };

@ -47,6 +47,7 @@ class GPUKernelRuntime : public KernelRuntime {
bool Run(session::KernelGraph *graph, bool is_task_sink) override;
bool GenDynamicKernel(const session::KernelGraph *graph) override { return true; }
bool RunDynamicKernelAsync(const session::KernelGraph *graph) override { return true; }
DeviceAddressType GetTargetDeviceAddressType() const override { return DeviceAddressType::kGPU; };
void *compute_stream() const override { return stream_; }
void *communication_stream() const override { return communication_stream_; }

@ -49,7 +49,9 @@ bool KernelRuntime::LoadData(session::KernelGraph *graph) { return false; }
bool KernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_t index) {
MS_EXCEPTION_IF_NULL(kernel);
if (AnfAlgo::OutputAddrExist(kernel, index)) {
return true;
const auto &address = AnfAlgo::GetOutputAddr(kernel, index);
MS_EXCEPTION_IF_NULL(address);
return address->DeviceType() == GetTargetDeviceAddressType();
}
return false;
}
@ -173,7 +175,7 @@ void KernelRuntime::RunOpAssignInputMemory(const std::vector<tensor::TensorPtr>
MS_EXCEPTION_IF_NULL(input_tensors[input_index]);
auto output_address =
std::dynamic_pointer_cast<device::DeviceAddress>(input_tensors[input_index]->device_address());
if (output_address != nullptr) {
if (output_address != nullptr && output_address->DeviceType() == GetTargetDeviceAddressType()) {
AnfAlgo::SetOutputAddr(output_address, index, item.get());
continue;
}
@ -637,7 +639,8 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const
MS_LOG(WARNING) << "Tensor is null";
return;
}
if (tensor->device_address() != nullptr) {
auto output_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
if (output_address != nullptr && output_address->DeviceType() == GetTargetDeviceAddressType()) {
AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), output_idx++,
value_node.get());
continue;

@ -108,6 +108,7 @@ class KernelRuntime {
virtual uint64_t GetAvailableMemMaxSize() const { return 0; }
void AddBufferPtr(std::shared_ptr<char[]> ptr) { buffer_ptrs_.push_back(ptr); }
void FreeAndClearBufferPtrs() { buffer_ptrs_.clear(); }
virtual DeviceAddressType GetTargetDeviceAddressType() const = 0;
virtual void *compute_stream() const { return nullptr; }
virtual void *communication_stream() const { return nullptr; }

Loading…
Cancel
Save