diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc index b7010e69dc..fb437730f2 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc @@ -19,6 +19,8 @@ #include #include #include +#include +#include #include "ir/anf.h" #include "ir/func_graph.h" #include "base/core_ops.h" @@ -480,6 +482,28 @@ size_t AnfRuntimeAlgorithm::GetOutputTensorNum(const AnfNodePtr &node) { return 1; } +size_t AnfRuntimeAlgorithm::GetOutputTensorMemSize(const AnfNodePtr &node, size_t output_index) { + MS_EXCEPTION_IF_NULL(node); + if (output_index >= AnfAlgo::GetOutputTensorNum(node)) { + MS_EXCEPTION(ArgumentError) << "output index [" << output_index << "] large than the output size [" + << AnfAlgo::GetOutputTensorNum(node) << "] of node!"; + } + TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index); + if (output_type_id == kTypeUnknown) { + output_type_id = AnfAlgo::GetOutputInferDataType(node, output_index); + } + size_t type_size = GetTypeByte(TypeIdToType(output_type_id)); + std::vector shape = AnfAlgo::GetOutputDeviceShape(node, output_index); + auto format = AnfAlgo::GetOutputFormat(node, output_index); + if (shape.empty() && format != kOpFormat_DEFAULT) { + shape = trans::PaddingShape(shape, format, AnfAlgo::GetOutputReshapeType(node, output_index)); + shape = trans::TransShapeToDevice(shape, format); + } + // scalar's output shape is a empty vector + size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies()); + return tensor_size; +} + std::vector AnfRuntimeAlgorithm::GetAllOutputFormats(const AnfNodePtr &node) { MS_EXCEPTION_IF_NULL(node); if (!AnfAlgo::IsRealKernel(node)) { diff --git a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h index 58dba2e36c..9d3fcf8f70 100644 --- a/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h +++ b/mindspore/ccsrc/backend/session/anf_runtime_algorithm.h @@ -105,6 +105,8 @@ class AnfRuntimeAlgorithm { static size_t GetInputTensorNum(const AnfNodePtr &node); // get the num of output real_kernel(which can be build and run in device) static size_t GetOutputTensorNum(const AnfNodePtr &node); + // Get the memory size of output tensor of node. + static size_t GetOutputTensorMemSize(const AnfNodePtr &node, size_t output_index); // get all outputs format select of anf node static std::vector GetAllOutputFormats(const AnfNodePtr &node); // get all inputs format select of anf node diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.cc b/mindspore/ccsrc/runtime/device/kernel_runtime.cc index 6e7be39253..e3d9ab1046 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.cc +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.cc @@ -16,7 +16,6 @@ #include "runtime/device/kernel_runtime.h" #include -#include #include #include #include "backend/optimizer/common/helper.h" @@ -57,28 +56,6 @@ bool KernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_ return false; } -size_t KernelRuntime::CountNodeDeviceMemorySize(const mindspore::AnfNodePtr &node, size_t output_index) { - MS_EXCEPTION_IF_NULL(node); - if (output_index >= AnfAlgo::GetOutputTensorNum(node)) { - MS_EXCEPTION(ArgumentError) << "output index [" << output_index << "] large than the output size [" - << AnfAlgo::GetOutputTensorNum(node) << "] of node!"; - } - TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index); - if (output_type_id == kTypeUnknown) { - output_type_id = AnfAlgo::GetOutputInferDataType(node, output_index); - } - size_t type_size = GetTypeByte(TypeIdToType(output_type_id)); - std::vector shape = AnfAlgo::GetOutputDeviceShape(node, output_index); - auto format = AnfAlgo::GetOutputFormat(node, output_index); - if (shape.empty() && format != kOpFormat_DEFAULT) { - shape = trans::PaddingShape(shape, format, AnfAlgo::GetOutputReshapeType(node, output_index)); - shape = trans::TransShapeToDevice(shape, format); - } - // scalar's output shape is a empty vector - size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies()); - return tensor_size; -} - void KernelRuntime::AssignMemory(session::KernelGraph *graph) { auto context_ptr = MsContext::GetInstance(); MS_EXCEPTION_IF_NULL(context_ptr); @@ -184,7 +161,7 @@ void KernelRuntime::RunOpAssignInputMemory(const std::vector if (output_type_id == kTypeUnknown) { output_type_id = AnfAlgo::GetOutputInferDataType(item, index); } - auto tensor_size = CountNodeDeviceMemorySize(item, index); + auto tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index); auto device_address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id); MS_EXCEPTION_IF_NULL(device_address); @@ -361,7 +338,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) { continue; } #endif - auto tensor_size = CountNodeDeviceMemorySize(item, index); + auto tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index); device_address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id); MS_LOG(DEBUG) << "Malloc static memory for " << item->fullname_with_scope(); if (mem_manager_->MallocMem(kStaticMem, tensor_size, device_address, graph->graph_id()) == nullptr) { @@ -656,7 +633,7 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const continue; } size_t tensor_size = tensor->data().nbytes(); - auto node_size = CountNodeDeviceMemorySize(value_node, output_idx); + auto node_size = AnfAlgo::GetOutputTensorMemSize(value_node, output_idx); TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx); if (output_type_id == kTypeUnknown) { output_type_id = AnfAlgo::GetOutputInferDataType(value_node, output_idx); diff --git a/mindspore/ccsrc/runtime/device/kernel_runtime.h b/mindspore/ccsrc/runtime/device/kernel_runtime.h index 1b9ddc94a6..d5c17b36cb 100644 --- a/mindspore/ccsrc/runtime/device/kernel_runtime.h +++ b/mindspore/ccsrc/runtime/device/kernel_runtime.h @@ -138,7 +138,6 @@ class KernelRuntime { bool LaunchKernelMod(const session::KernelGraph &graph); void LaunchKernelEvent(const std::vector>> &run_events, size_t index); static void GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs); - size_t CountNodeDeviceMemorySize(const AnfNodePtr &node, size_t output_index); void RunOpAssignInputMemory(const std::vector &input_tensors, const session::KernelGraph *graph); void RunOpAssignOutputMemory(const AnfNodePtr &kernel); void RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel); diff --git a/mindspore/ccsrc/runtime/framework/graph_compiler.cc b/mindspore/ccsrc/runtime/framework/graph_compiler.cc index add58f3f18..535f23c013 100644 --- a/mindspore/ccsrc/runtime/framework/graph_compiler.cc +++ b/mindspore/ccsrc/runtime/framework/graph_compiler.cc @@ -15,29 +15,200 @@ */ #include "runtime/framework/graph_compiler.h" +#include +#include #include "runtime/framework/graph_scheduler.h" +#include "runtime/device/device_address.h" +#include "common/trans.h" +#include "utils/convert_utils.h" +#include "ir/tensor.h" namespace mindspore { namespace runtime { -void GraphCompiler::set_device_context(device::DeviceContext *device_context) { +namespace { +// Whether device address of anf node is valid and device address type +// is consistent with device type, for example, device address type +// DeviceAddressType::kGPU should be used on GPU device +bool NodeDeviceAddressExist(const DeviceContext *device_context, const AnfNodePtr &kernel, size_t index) { + MS_EXCEPTION_IF_NULL(kernel); + MS_EXCEPTION_IF_NULL(device_context); + if (AnfAlgo::OutputAddrExist(kernel, index)) { + const auto &address = AnfAlgo::GetOutputAddr(kernel, index); + MS_EXCEPTION_IF_NULL(address); + return address->DeviceType() == device_context->GetDeviceAddressType(); + } + return false; +} + +void CreateParameterDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) { + MS_EXCEPTION_IF_NULL(device_context); + MS_EXCEPTION_IF_NULL(graph); + std::vector graph_inputs = graph->inputs(); + const std::vector &graph_valid_input = graph->valid_inputs(); + graph_inputs.insert(graph_inputs.end(), graph->child_graph_result().begin(), graph->child_graph_result().end()); + + // Anf nodes which need create device address. + std::vector nodes_list; + for (size_t i = 0; i < graph_inputs.size(); ++i) { + AnfNodePtr item = graph_inputs[i]; + MS_EXCEPTION_IF_NULL(item); + if (i < graph_valid_input.size() && !graph_valid_input[i]) { + continue; + } + + if (AnfAlgo::CheckPrimitiveType(item, prim::kPrimMakeTuple)) { + std::vector outs = AnfAlgo::GetAllOutput(item); + for (const auto &out : outs) { + MS_EXCEPTION_IF_NULL(out); + if (!out->isa() || NodeDeviceAddressExist(device_context, out, 0)) { + continue; + } + nodes_list.push_back(out); + } + } + if (!item->isa() || NodeDeviceAddressExist(device_context, item, 0)) { + continue; + } + nodes_list.push_back(item); + } + + // Create device address for anf node in nodes_list + for (const auto &item : nodes_list) { + auto output_size = AnfAlgo::GetOutputTensorNum(item); + for (size_t index = 0; index < output_size; index++) { + TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index); + // if graph output is a weight and doesn't link to any cnode, it's data type will be unknown + if (output_type_id == kTypeUnknown) { + MS_LOG(WARNING) << "It is not suggested to use a lonely weight parameter as the output of graph"; + continue; + } + + size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index); + auto device_address = device_context->CreateDeviceAddress(nullptr, tensor_size, + AnfAlgo::GetOutputFormat(item, index), output_type_id); + AnfAlgo::SetOutputAddr(device_address, index, item.get()); + } + } +} + +void CreateDeviceAddressForTensorValue(const DeviceContext *device_context, const ValuePtr &node_value, + size_t output_idx, const ValueNodePtr &value_node) { + MS_EXCEPTION_IF_NULL(device_context); + MS_EXCEPTION_IF_NULL(node_value); + MS_EXCEPTION_IF_NULL(value_node); + const auto &ms_context = MsContext::GetInstance(); + MS_EXCEPTION_IF_NULL(ms_context); + std::vector tensors; + TensorValueToTensor(node_value, &tensors); + + for (const auto &tensor : tensors) { + if (tensor == nullptr) { + MS_LOG(WARNING) << "Tensor is null"; + return; + } + auto output_address = std::dynamic_pointer_cast(tensor->device_address()); + if (output_address != nullptr && output_address->DeviceType() == device_context->GetDeviceAddressType()) { + AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast(tensor->device_address()), output_idx++, + value_node.get()); + continue; + } + + size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(value_node, output_idx); + TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx); + if (output_type_id == kTypeUnknown) { + output_type_id = AnfAlgo::GetOutputInferDataType(value_node, output_idx); + } + std::string output_format = AnfAlgo::GetOutputFormat(value_node, output_idx); + + device::DeviceAddressPtr address = + device_context->CreateDeviceAddress(nullptr, tensor_size, output_format, output_type_id); + MS_EXCEPTION_IF_NULL(address); + AnfAlgo::SetOutputAddr(address, output_idx, value_node.get()); + } +} + +void CreateValueNodeDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) { + MS_EXCEPTION_IF_NULL(device_context); + MS_EXCEPTION_IF_NULL(graph); + for (const ValueNodePtr &value_node : graph->graph_value_nodes()) { + MS_EXCEPTION_IF_NULL(value_node); + if (NodeDeviceAddressExist(device_context, value_node, 0)) { + continue; + } + + const auto &node_value = value_node->value(); + MS_EXCEPTION_IF_NULL(node_value); + if (node_value->isa() || node_value->isa()) { + CreateDeviceAddressForTensorValue(device_context, node_value, 0, value_node); + } else if (node_value->isa()) { + auto value = GetValue(node_value); + size_t tensor_size = value.size(); + auto address = device_context->CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8); + MS_EXCEPTION_IF_NULL(address); + + AnfAlgo::SetOutputAddr(address, 0, value_node.get()); + } + } +} + +void CreateKernelOutputDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) { + MS_EXCEPTION_IF_NULL(device_context); + MS_EXCEPTION_IF_NULL(graph); + const std::vector &kernels = graph->execution_order(); + for (const auto &kernel : kernels) { + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + auto output_sizes = kernel_mod->GetOutputSizeList(); + for (size_t i = 0; i < output_sizes.size(); ++i) { + if (AnfAlgo::OutputAddrExist(kernel, i)) { + continue; + } + + std::string output_format = AnfAlgo::GetOutputFormat(kernel, i); + auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i); + auto device_address = device_context->CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type); + AnfAlgo::SetOutputAddr(device_address, i, kernel.get()); + } + } +} + +void CreateKernelWorkspaceDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) { + MS_EXCEPTION_IF_NULL(device_context); + MS_EXCEPTION_IF_NULL(graph); + const std::vector &kernels = graph->execution_order(); + for (const auto &kernel : kernels) { + auto kernel_mod = AnfAlgo::GetKernelMod(kernel); + MS_EXCEPTION_IF_NULL(kernel_mod); + auto workspace_sizes = kernel_mod->GetWorkspaceSizeList(); + for (size_t i = 0; i < workspace_sizes.size(); ++i) { + auto device_address = device_context->CreateDeviceAddress(nullptr, workspace_sizes[i], "", kTypeUnknown); + AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get()); + } + } +} +} // namespace + +void GraphCompiler::set_device_context(DeviceContext *device_context) { MS_EXCEPTION_IF_NULL(device_context); device_context_ = device_context; // The member variable 'session_' will be removed after removing session module. if (session_ == nullptr) { session_ = std::make_shared(); + const device::DeviceContextKey &device_context_key = device_context->device_context_key(); + session_->InitExecutor(device_context_key.device_name_, device_context_key.device_id_); } } GraphId GraphCompiler::CompileGraph(const AnfNodePtrList &nodes, const AnfNodePtrList &outputs) { MS_EXCEPTION_IF_NULL(session_); // Generate kernel graph. - auto graph = session_->ConstructKernelGraph(nodes, outputs); + KernelGraphPtr graph = session_->ConstructKernelGraph(nodes, outputs); MS_EXCEPTION_IF_NULL(graph); return CompileGraphImpl(graph); } -GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph) { +GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph) const { MS_EXCEPTION_IF_NULL(device_context_); // Optimization pass which is irrelevant to device type or format. device_context_->OptimizeGraphWithoutDeviceInfo(graph); @@ -51,6 +222,8 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph) { // 'KernelMod' is real executive object of kernel. device_context_->CreateKernel(graph->execution_order()); + // Create device address for all anf nodes of graph. + CreateDeviceAddress(graph); // Transform graph to actor DAG, contains build and link. GraphScheduler::GetInstance().Transform(graph, device_context_); return graph->graph_id(); @@ -68,7 +241,7 @@ GraphId GraphCompiler::CompileGraph(session::OpRunInfo *op_run_info, const Graph } // Generate kernel graph. MS_EXCEPTION_IF_NULL(session_); - auto graph = session_->ConstructSingleOpGraph(*op_run_info, *input_tensors, tensors_mask); + KernelGraphPtr graph = session_->ConstructSingleOpGraph(*op_run_info, *input_tensors, tensors_mask); MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(device_context_); @@ -82,6 +255,8 @@ GraphId GraphCompiler::CompileGraph(session::OpRunInfo *op_run_info, const Graph // Generate 'KernelMod' for kernel in graph. device_context_->CreateKernel(graph->execution_order()); + // Create device address for all anf nodes of graph. + CreateDeviceAddress(graph); // Transform graph to actor DAG, contains build and link. GraphScheduler::GetInstance().Transform(graph, device_context_, input_tensors, GraphExecutionStrategy::kStep); run_op_graphs_[graph_info] = graph; @@ -101,5 +276,12 @@ KernelGraphPtr GraphCompiler::Fetch(const GraphInfo &graph_info) const { } return iter->second; } + +void GraphCompiler::CreateDeviceAddress(const KernelGraphPtr &graph) const { + CreateParameterDeviceAddress(device_context_, graph); + CreateValueNodeDeviceAddress(device_context_, graph); + CreateKernelOutputDeviceAddress(device_context_, graph); + CreateKernelWorkspaceDeviceAddress(device_context_, graph); +} } // namespace runtime } // namespace mindspore diff --git a/mindspore/ccsrc/runtime/framework/graph_compiler.h b/mindspore/ccsrc/runtime/framework/graph_compiler.h index 1db27cc0e8..2222076e72 100644 --- a/mindspore/ccsrc/runtime/framework/graph_compiler.h +++ b/mindspore/ccsrc/runtime/framework/graph_compiler.h @@ -26,6 +26,7 @@ namespace mindspore { namespace runtime { +using device::DeviceContext; class GraphCompiler { public: static GraphCompiler &GetInstance() { @@ -35,7 +36,7 @@ class GraphCompiler { // Set device context which is initialized, the function must be called // before using GraphCompiler and after changing device type or device id. - void set_device_context(device::DeviceContext *device_context); + void set_device_context(DeviceContext *device_context); // Construct kernel graph from anf nodes list and compile kernel graph in Graph mode, // the detailed implementation of compiling graph is in 'CompileGraphImpl'. @@ -58,9 +59,12 @@ class GraphCompiler { // The implementation of compiling graph in Graph Mode, including optimizing graph, // setting operator info, creating kernel and transforming kernel graph to ActorSet. - GraphId CompileGraphImpl(const KernelGraphPtr &graph); + GraphId CompileGraphImpl(const KernelGraphPtr &graph) const; - device::DeviceContext *device_context_{nullptr}; + // Create device address for all anf nodes of graph. + void CreateDeviceAddress(const KernelGraphPtr &graph) const; + + DeviceContext *device_context_{nullptr}; // Single op kernel graph cache for PyNative mode. std::unordered_map run_op_graphs_; diff --git a/mindspore/ccsrc/runtime/hardware/cpu/cpu_device_context.cc b/mindspore/ccsrc/runtime/hardware/cpu/cpu_device_context.cc index f2d0b224ca..cbb24a968e 100644 --- a/mindspore/ccsrc/runtime/hardware/cpu/cpu_device_context.cc +++ b/mindspore/ccsrc/runtime/hardware/cpu/cpu_device_context.cc @@ -50,6 +50,11 @@ void CPUDeviceContext::FreeMemory(DeviceAddress *const &address) const { address->ptr_ = nullptr; } +DeviceAddressPtr CPUDeviceContext::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, + TypeId type_id) const { + return std::make_shared(device_ptr, device_size, format, type_id); +} + void CPUDeviceContext::OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const { // Update Graph Dynamic Shape Attr. UpdateGraphDynamicShapeAttr(NOT_NULL(graph)); diff --git a/mindspore/ccsrc/runtime/hardware/cpu/cpu_device_context.h b/mindspore/ccsrc/runtime/hardware/cpu/cpu_device_context.h index 3c451e2a2e..9123ade98d 100644 --- a/mindspore/ccsrc/runtime/hardware/cpu/cpu_device_context.h +++ b/mindspore/ccsrc/runtime/hardware/cpu/cpu_device_context.h @@ -18,6 +18,7 @@ #include #include +#include #include "runtime/hardware/device_context.h" #include "runtime/hardware/device_context_manager.h" #include "runtime/device/memory_manager.h" @@ -36,6 +37,10 @@ class CPUDeviceContext : public DeviceContext { bool AllocateMemory(DeviceAddress *const &address, size_t size) const override; void FreeMemory(DeviceAddress *const &address) const override; + DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, + TypeId type_id) const override; + DeviceAddressType GetDeviceAddressType() const override { return DeviceAddressType::kCPU; } + void OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const override; void OptimizeSingleOpGraph(const KernelGraphPtr &graph) const override; diff --git a/mindspore/ccsrc/runtime/hardware/device_context.h b/mindspore/ccsrc/runtime/hardware/device_context.h index de520996b3..5ee12e852c 100644 --- a/mindspore/ccsrc/runtime/hardware/device_context.h +++ b/mindspore/ccsrc/runtime/hardware/device_context.h @@ -63,6 +63,13 @@ class DeviceContext { return true; } + // Create concrete device address according different device type. + virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, + TypeId type_id) const = 0; + + // Get device address type according different device type, such GPU, Ascend. + virtual DeviceAddressType GetDeviceAddressType() const = 0; + // The two functions below will be merged to one in the future. // General graph optimezer ignore device data type and format. virtual void OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const {} @@ -90,6 +97,9 @@ class DeviceContext { // Devices that do not need stream could ignore the implementation of this function. virtual bool SyncStream(size_t stream_id = 0) { return true; } + // Get device_context_key_ to obtain device name and device id. + const DeviceContextKey &device_context_key() const { return device_context_key_; } + protected: DeviceContextKey device_context_key_; }; diff --git a/mindspore/ccsrc/runtime/hardware/gpu/gpu_device_context.cc b/mindspore/ccsrc/runtime/hardware/gpu/gpu_device_context.cc index 48cfcd2c74..39e53e4ab2 100644 --- a/mindspore/ccsrc/runtime/hardware/gpu/gpu_device_context.cc +++ b/mindspore/ccsrc/runtime/hardware/gpu/gpu_device_context.cc @@ -165,6 +165,11 @@ bool GPUDeviceContext::AllocateContinuousMemory(const std::vector(device_ptr, device_size, format, type_id); +} + void GPUDeviceContext::OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const { MS_EXCEPTION_IF_NULL(graph); // Operator fusion optimization. diff --git a/mindspore/ccsrc/runtime/hardware/gpu/gpu_device_context.h b/mindspore/ccsrc/runtime/hardware/gpu/gpu_device_context.h index b66dc486f5..35e11011d8 100644 --- a/mindspore/ccsrc/runtime/hardware/gpu/gpu_device_context.h +++ b/mindspore/ccsrc/runtime/hardware/gpu/gpu_device_context.h @@ -19,6 +19,7 @@ #include #include +#include #include "runtime/hardware/device_context.h" #include "runtime/hardware/device_context_manager.h" #include "runtime/device/memory_manager.h" @@ -43,6 +44,10 @@ class GPUDeviceContext : public DeviceContext { bool AllocateContinuousMemory(const std::vector &addr_list, size_t total_size, const std::vector &size_list) const override; + DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format, + TypeId type_id) const override; + DeviceAddressType GetDeviceAddressType() const override { return DeviceAddressType::kGPU; } + // General graph optimezer ignore device data type and format. void OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const override; // Optimize the kernel graph according to device type, such format transform.