init device address for anf node

pull/14610/head
lizhenyu 4 years ago
parent 536d6bd1e5
commit a87b141cf9

@ -19,6 +19,8 @@
#include <map>
#include <set>
#include <unordered_set>
#include <functional>
#include <numeric>
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "base/core_ops.h"
@ -480,6 +482,28 @@ size_t AnfRuntimeAlgorithm::GetOutputTensorNum(const AnfNodePtr &node) {
return 1;
}
size_t AnfRuntimeAlgorithm::GetOutputTensorMemSize(const AnfNodePtr &node, size_t output_index) {
MS_EXCEPTION_IF_NULL(node);
if (output_index >= AnfAlgo::GetOutputTensorNum(node)) {
MS_EXCEPTION(ArgumentError) << "output index [" << output_index << "] large than the output size ["
<< AnfAlgo::GetOutputTensorNum(node) << "] of node!";
}
TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index);
if (output_type_id == kTypeUnknown) {
output_type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
}
size_t type_size = GetTypeByte(TypeIdToType(output_type_id));
std::vector<size_t> shape = AnfAlgo::GetOutputDeviceShape(node, output_index);
auto format = AnfAlgo::GetOutputFormat(node, output_index);
if (shape.empty() && format != kOpFormat_DEFAULT) {
shape = trans::PaddingShape(shape, format, AnfAlgo::GetOutputReshapeType(node, output_index));
shape = trans::TransShapeToDevice(shape, format);
}
// scalar's output shape is a empty vector
size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
return tensor_size;
}
std::vector<std::string> AnfRuntimeAlgorithm::GetAllOutputFormats(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (!AnfAlgo::IsRealKernel(node)) {

@ -105,6 +105,8 @@ class AnfRuntimeAlgorithm {
static size_t GetInputTensorNum(const AnfNodePtr &node);
// get the num of output real_kernel(which can be build and run in device)
static size_t GetOutputTensorNum(const AnfNodePtr &node);
// Get the memory size of output tensor of node.
static size_t GetOutputTensorMemSize(const AnfNodePtr &node, size_t output_index);
// get all outputs format select of anf node
static std::vector<std::string> GetAllOutputFormats(const AnfNodePtr &node);
// get all inputs format select of anf node

@ -16,7 +16,6 @@
#include "runtime/device/kernel_runtime.h"
#include <functional>
#include <numeric>
#include <utility>
#include <vector>
#include "backend/optimizer/common/helper.h"
@ -57,28 +56,6 @@ bool KernelRuntime::NodeOutputDeviceAddressExist(const AnfNodePtr &kernel, size_
return false;
}
size_t KernelRuntime::CountNodeDeviceMemorySize(const mindspore::AnfNodePtr &node, size_t output_index) {
MS_EXCEPTION_IF_NULL(node);
if (output_index >= AnfAlgo::GetOutputTensorNum(node)) {
MS_EXCEPTION(ArgumentError) << "output index [" << output_index << "] large than the output size ["
<< AnfAlgo::GetOutputTensorNum(node) << "] of node!";
}
TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(node, output_index);
if (output_type_id == kTypeUnknown) {
output_type_id = AnfAlgo::GetOutputInferDataType(node, output_index);
}
size_t type_size = GetTypeByte(TypeIdToType(output_type_id));
std::vector<size_t> shape = AnfAlgo::GetOutputDeviceShape(node, output_index);
auto format = AnfAlgo::GetOutputFormat(node, output_index);
if (shape.empty() && format != kOpFormat_DEFAULT) {
shape = trans::PaddingShape(shape, format, AnfAlgo::GetOutputReshapeType(node, output_index));
shape = trans::TransShapeToDevice(shape, format);
}
// scalar's output shape is a empty vector
size_t tensor_size = std::accumulate(shape.begin(), shape.end(), type_size, std::multiplies<size_t>());
return tensor_size;
}
void KernelRuntime::AssignMemory(session::KernelGraph *graph) {
auto context_ptr = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(context_ptr);
@ -184,7 +161,7 @@ void KernelRuntime::RunOpAssignInputMemory(const std::vector<tensor::TensorPtr>
if (output_type_id == kTypeUnknown) {
output_type_id = AnfAlgo::GetOutputInferDataType(item, index);
}
auto tensor_size = CountNodeDeviceMemorySize(item, index);
auto tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
auto device_address =
CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id);
MS_EXCEPTION_IF_NULL(device_address);
@ -361,7 +338,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
continue;
}
#endif
auto tensor_size = CountNodeDeviceMemorySize(item, index);
auto tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
device_address = CreateDeviceAddress(nullptr, tensor_size, AnfAlgo::GetOutputFormat(item, index), output_type_id);
MS_LOG(DEBUG) << "Malloc static memory for " << item->fullname_with_scope();
if (mem_manager_->MallocMem(kStaticMem, tensor_size, device_address, graph->graph_id()) == nullptr) {
@ -656,7 +633,7 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const
continue;
}
size_t tensor_size = tensor->data().nbytes();
auto node_size = CountNodeDeviceMemorySize(value_node, output_idx);
auto node_size = AnfAlgo::GetOutputTensorMemSize(value_node, output_idx);
TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx);
if (output_type_id == kTypeUnknown) {
output_type_id = AnfAlgo::GetOutputInferDataType(value_node, output_idx);

@ -138,7 +138,6 @@ class KernelRuntime {
bool LaunchKernelMod(const session::KernelGraph &graph);
void LaunchKernelEvent(const std::vector<std::vector<std::function<void()>>> &run_events, size_t index);
static void GenAddrCleanLaunchArgs(const CNodePtr &cnode, AddressPtrList *kernel_inputs);
size_t CountNodeDeviceMemorySize(const AnfNodePtr &node, size_t output_index);
void RunOpAssignInputMemory(const std::vector<tensor::TensorPtr> &input_tensors, const session::KernelGraph *graph);
void RunOpAssignOutputMemory(const AnfNodePtr &kernel);
void RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel);

@ -15,29 +15,200 @@
*/
#include "runtime/framework/graph_compiler.h"
#include <numeric>
#include <map>
#include "runtime/framework/graph_scheduler.h"
#include "runtime/device/device_address.h"
#include "common/trans.h"
#include "utils/convert_utils.h"
#include "ir/tensor.h"
namespace mindspore {
namespace runtime {
void GraphCompiler::set_device_context(device::DeviceContext *device_context) {
namespace {
// Whether device address of anf node is valid and device address type
// is consistent with device type, for example, device address type
// DeviceAddressType::kGPU should be used on GPU device
bool NodeDeviceAddressExist(const DeviceContext *device_context, const AnfNodePtr &kernel, size_t index) {
MS_EXCEPTION_IF_NULL(kernel);
MS_EXCEPTION_IF_NULL(device_context);
if (AnfAlgo::OutputAddrExist(kernel, index)) {
const auto &address = AnfAlgo::GetOutputAddr(kernel, index);
MS_EXCEPTION_IF_NULL(address);
return address->DeviceType() == device_context->GetDeviceAddressType();
}
return false;
}
void CreateParameterDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(device_context);
MS_EXCEPTION_IF_NULL(graph);
std::vector<AnfNodePtr> graph_inputs = graph->inputs();
const std::vector<bool> &graph_valid_input = graph->valid_inputs();
graph_inputs.insert(graph_inputs.end(), graph->child_graph_result().begin(), graph->child_graph_result().end());
// Anf nodes which need create device address.
std::vector<AnfNodePtr> nodes_list;
for (size_t i = 0; i < graph_inputs.size(); ++i) {
AnfNodePtr item = graph_inputs[i];
MS_EXCEPTION_IF_NULL(item);
if (i < graph_valid_input.size() && !graph_valid_input[i]) {
continue;
}
if (AnfAlgo::CheckPrimitiveType(item, prim::kPrimMakeTuple)) {
std::vector<AnfNodePtr> outs = AnfAlgo::GetAllOutput(item);
for (const auto &out : outs) {
MS_EXCEPTION_IF_NULL(out);
if (!out->isa<Parameter>() || NodeDeviceAddressExist(device_context, out, 0)) {
continue;
}
nodes_list.push_back(out);
}
}
if (!item->isa<Parameter>() || NodeDeviceAddressExist(device_context, item, 0)) {
continue;
}
nodes_list.push_back(item);
}
// Create device address for anf node in nodes_list
for (const auto &item : nodes_list) {
auto output_size = AnfAlgo::GetOutputTensorNum(item);
for (size_t index = 0; index < output_size; index++) {
TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(item, index);
// if graph output is a weight and doesn't link to any cnode, it's data type will be unknown
if (output_type_id == kTypeUnknown) {
MS_LOG(WARNING) << "It is not suggested to use a lonely weight parameter as the output of graph";
continue;
}
size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(item, index);
auto device_address = device_context->CreateDeviceAddress(nullptr, tensor_size,
AnfAlgo::GetOutputFormat(item, index), output_type_id);
AnfAlgo::SetOutputAddr(device_address, index, item.get());
}
}
}
void CreateDeviceAddressForTensorValue(const DeviceContext *device_context, const ValuePtr &node_value,
size_t output_idx, const ValueNodePtr &value_node) {
MS_EXCEPTION_IF_NULL(device_context);
MS_EXCEPTION_IF_NULL(node_value);
MS_EXCEPTION_IF_NULL(value_node);
const auto &ms_context = MsContext::GetInstance();
MS_EXCEPTION_IF_NULL(ms_context);
std::vector<tensor::TensorPtr> tensors;
TensorValueToTensor(node_value, &tensors);
for (const auto &tensor : tensors) {
if (tensor == nullptr) {
MS_LOG(WARNING) << "Tensor is null";
return;
}
auto output_address = std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address());
if (output_address != nullptr && output_address->DeviceType() == device_context->GetDeviceAddressType()) {
AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), output_idx++,
value_node.get());
continue;
}
size_t tensor_size = AnfAlgo::GetOutputTensorMemSize(value_node, output_idx);
TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx);
if (output_type_id == kTypeUnknown) {
output_type_id = AnfAlgo::GetOutputInferDataType(value_node, output_idx);
}
std::string output_format = AnfAlgo::GetOutputFormat(value_node, output_idx);
device::DeviceAddressPtr address =
device_context->CreateDeviceAddress(nullptr, tensor_size, output_format, output_type_id);
MS_EXCEPTION_IF_NULL(address);
AnfAlgo::SetOutputAddr(address, output_idx, value_node.get());
}
}
void CreateValueNodeDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(device_context);
MS_EXCEPTION_IF_NULL(graph);
for (const ValueNodePtr &value_node : graph->graph_value_nodes()) {
MS_EXCEPTION_IF_NULL(value_node);
if (NodeDeviceAddressExist(device_context, value_node, 0)) {
continue;
}
const auto &node_value = value_node->value();
MS_EXCEPTION_IF_NULL(node_value);
if (node_value->isa<tensor::Tensor>() || node_value->isa<ValueTuple>()) {
CreateDeviceAddressForTensorValue(device_context, node_value, 0, value_node);
} else if (node_value->isa<StringImm>()) {
auto value = GetValue<std::string>(node_value);
size_t tensor_size = value.size();
auto address = device_context->CreateDeviceAddress(nullptr, tensor_size, kOpFormat_DEFAULT, kNumberTypeUInt8);
MS_EXCEPTION_IF_NULL(address);
AnfAlgo::SetOutputAddr(address, 0, value_node.get());
}
}
}
void CreateKernelOutputDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(device_context);
MS_EXCEPTION_IF_NULL(graph);
const std::vector<CNodePtr> &kernels = graph->execution_order();
for (const auto &kernel : kernels) {
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
MS_EXCEPTION_IF_NULL(kernel_mod);
auto output_sizes = kernel_mod->GetOutputSizeList();
for (size_t i = 0; i < output_sizes.size(); ++i) {
if (AnfAlgo::OutputAddrExist(kernel, i)) {
continue;
}
std::string output_format = AnfAlgo::GetOutputFormat(kernel, i);
auto output_type = AnfAlgo::GetOutputDeviceDataType(kernel, i);
auto device_address = device_context->CreateDeviceAddress(nullptr, output_sizes[i], output_format, output_type);
AnfAlgo::SetOutputAddr(device_address, i, kernel.get());
}
}
}
void CreateKernelWorkspaceDeviceAddress(const DeviceContext *device_context, const KernelGraphPtr &graph) {
MS_EXCEPTION_IF_NULL(device_context);
MS_EXCEPTION_IF_NULL(graph);
const std::vector<CNodePtr> &kernels = graph->execution_order();
for (const auto &kernel : kernels) {
auto kernel_mod = AnfAlgo::GetKernelMod(kernel);
MS_EXCEPTION_IF_NULL(kernel_mod);
auto workspace_sizes = kernel_mod->GetWorkspaceSizeList();
for (size_t i = 0; i < workspace_sizes.size(); ++i) {
auto device_address = device_context->CreateDeviceAddress(nullptr, workspace_sizes[i], "", kTypeUnknown);
AnfAlgo::SetWorkspaceAddr(device_address, i, kernel.get());
}
}
}
} // namespace
void GraphCompiler::set_device_context(DeviceContext *device_context) {
MS_EXCEPTION_IF_NULL(device_context);
device_context_ = device_context;
// The member variable 'session_' will be removed after removing session module.
if (session_ == nullptr) {
session_ = std::make_shared<session::SessionBasic>();
const device::DeviceContextKey &device_context_key = device_context->device_context_key();
session_->InitExecutor(device_context_key.device_name_, device_context_key.device_id_);
}
}
GraphId GraphCompiler::CompileGraph(const AnfNodePtrList &nodes, const AnfNodePtrList &outputs) {
MS_EXCEPTION_IF_NULL(session_);
// Generate kernel graph.
auto graph = session_->ConstructKernelGraph(nodes, outputs);
KernelGraphPtr graph = session_->ConstructKernelGraph(nodes, outputs);
MS_EXCEPTION_IF_NULL(graph);
return CompileGraphImpl(graph);
}
GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph) {
GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph) const {
MS_EXCEPTION_IF_NULL(device_context_);
// Optimization pass which is irrelevant to device type or format.
device_context_->OptimizeGraphWithoutDeviceInfo(graph);
@ -51,6 +222,8 @@ GraphId GraphCompiler::CompileGraphImpl(const KernelGraphPtr &graph) {
// 'KernelMod' is real executive object of kernel.
device_context_->CreateKernel(graph->execution_order());
// Create device address for all anf nodes of graph.
CreateDeviceAddress(graph);
// Transform graph to actor DAG, contains build and link.
GraphScheduler::GetInstance().Transform(graph, device_context_);
return graph->graph_id();
@ -68,7 +241,7 @@ GraphId GraphCompiler::CompileGraph(session::OpRunInfo *op_run_info, const Graph
}
// Generate kernel graph.
MS_EXCEPTION_IF_NULL(session_);
auto graph = session_->ConstructSingleOpGraph(*op_run_info, *input_tensors, tensors_mask);
KernelGraphPtr graph = session_->ConstructSingleOpGraph(*op_run_info, *input_tensors, tensors_mask);
MS_EXCEPTION_IF_NULL(graph);
MS_EXCEPTION_IF_NULL(device_context_);
@ -82,6 +255,8 @@ GraphId GraphCompiler::CompileGraph(session::OpRunInfo *op_run_info, const Graph
// Generate 'KernelMod' for kernel in graph.
device_context_->CreateKernel(graph->execution_order());
// Create device address for all anf nodes of graph.
CreateDeviceAddress(graph);
// Transform graph to actor DAG, contains build and link.
GraphScheduler::GetInstance().Transform(graph, device_context_, input_tensors, GraphExecutionStrategy::kStep);
run_op_graphs_[graph_info] = graph;
@ -101,5 +276,12 @@ KernelGraphPtr GraphCompiler::Fetch(const GraphInfo &graph_info) const {
}
return iter->second;
}
void GraphCompiler::CreateDeviceAddress(const KernelGraphPtr &graph) const {
CreateParameterDeviceAddress(device_context_, graph);
CreateValueNodeDeviceAddress(device_context_, graph);
CreateKernelOutputDeviceAddress(device_context_, graph);
CreateKernelWorkspaceDeviceAddress(device_context_, graph);
}
} // namespace runtime
} // namespace mindspore

@ -26,6 +26,7 @@
namespace mindspore {
namespace runtime {
using device::DeviceContext;
class GraphCompiler {
public:
static GraphCompiler &GetInstance() {
@ -35,7 +36,7 @@ class GraphCompiler {
// Set device context which is initialized, the function must be called
// before using GraphCompiler and after changing device type or device id.
void set_device_context(device::DeviceContext *device_context);
void set_device_context(DeviceContext *device_context);
// Construct kernel graph from anf nodes list and compile kernel graph in Graph mode,
// the detailed implementation of compiling graph is in 'CompileGraphImpl'.
@ -58,9 +59,12 @@ class GraphCompiler {
// The implementation of compiling graph in Graph Mode, including optimizing graph,
// setting operator info, creating kernel and transforming kernel graph to ActorSet.
GraphId CompileGraphImpl(const KernelGraphPtr &graph);
GraphId CompileGraphImpl(const KernelGraphPtr &graph) const;
device::DeviceContext *device_context_{nullptr};
// Create device address for all anf nodes of graph.
void CreateDeviceAddress(const KernelGraphPtr &graph) const;
DeviceContext *device_context_{nullptr};
// Single op kernel graph cache for PyNative mode.
std::unordered_map<GraphInfo, KernelGraphPtr> run_op_graphs_;

@ -50,6 +50,11 @@ void CPUDeviceContext::FreeMemory(DeviceAddress *const &address) const {
address->ptr_ = nullptr;
}
DeviceAddressPtr CPUDeviceContext::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
TypeId type_id) const {
return std::make_shared<CPUDeviceAddress>(device_ptr, device_size, format, type_id);
}
void CPUDeviceContext::OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const {
// Update Graph Dynamic Shape Attr.
UpdateGraphDynamicShapeAttr(NOT_NULL(graph));

@ -18,6 +18,7 @@
#include <vector>
#include <memory>
#include <string>
#include "runtime/hardware/device_context.h"
#include "runtime/hardware/device_context_manager.h"
#include "runtime/device/memory_manager.h"
@ -36,6 +37,10 @@ class CPUDeviceContext : public DeviceContext {
bool AllocateMemory(DeviceAddress *const &address, size_t size) const override;
void FreeMemory(DeviceAddress *const &address) const override;
DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
TypeId type_id) const override;
DeviceAddressType GetDeviceAddressType() const override { return DeviceAddressType::kCPU; }
void OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const override;
void OptimizeSingleOpGraph(const KernelGraphPtr &graph) const override;

@ -63,6 +63,13 @@ class DeviceContext {
return true;
}
// Create concrete device address according different device type.
virtual DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
TypeId type_id) const = 0;
// Get device address type according different device type, such GPU, Ascend.
virtual DeviceAddressType GetDeviceAddressType() const = 0;
// The two functions below will be merged to one in the future.
// General graph optimezer ignore device data type and format.
virtual void OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const {}
@ -90,6 +97,9 @@ class DeviceContext {
// Devices that do not need stream could ignore the implementation of this function.
virtual bool SyncStream(size_t stream_id = 0) { return true; }
// Get device_context_key_ to obtain device name and device id.
const DeviceContextKey &device_context_key() const { return device_context_key_; }
protected:
DeviceContextKey device_context_key_;
};

@ -165,6 +165,11 @@ bool GPUDeviceContext::AllocateContinuousMemory(const std::vector<DeviceAddress
return true;
}
DeviceAddressPtr GPUDeviceContext::CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
TypeId type_id) const {
return std::make_shared<GPUDeviceAddress>(device_ptr, device_size, format, type_id);
}
void GPUDeviceContext::OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const {
MS_EXCEPTION_IF_NULL(graph);
// Operator fusion optimization.

@ -19,6 +19,7 @@
#include <vector>
#include <memory>
#include <string>
#include "runtime/hardware/device_context.h"
#include "runtime/hardware/device_context_manager.h"
#include "runtime/device/memory_manager.h"
@ -43,6 +44,10 @@ class GPUDeviceContext : public DeviceContext {
bool AllocateContinuousMemory(const std::vector<DeviceAddress *> &addr_list, size_t total_size,
const std::vector<size_t> &size_list) const override;
DeviceAddressPtr CreateDeviceAddress(void *device_ptr, size_t device_size, const string &format,
TypeId type_id) const override;
DeviceAddressType GetDeviceAddressType() const override { return DeviceAddressType::kGPU; }
// General graph optimezer ignore device data type and format.
void OptimizeGraphWithoutDeviceInfo(const KernelGraphPtr &graph) const override;
// Optimize the kernel graph according to device type, such format transform.

Loading…
Cancel
Save