|
|
|
@ -29,6 +29,7 @@
|
|
|
|
|
#include "backend/session/anf_runtime_algorithm.h"
|
|
|
|
|
#include "backend/kernel_compiler/common_utils.h"
|
|
|
|
|
#include "backend/kernel_compiler/oplib/oplib.h"
|
|
|
|
|
#include "backend/optimizer/common/helper.h"
|
|
|
|
|
#include "ir/value.h"
|
|
|
|
|
using mindspore::kernel::Address;
|
|
|
|
|
using mindspore::kernel::AddressPtr;
|
|
|
|
@ -150,11 +151,13 @@ void KernelRuntime::AssignMemory(session::KernelGraph *graph) {
|
|
|
|
|
UpdateRefNodeOutputMem(graph);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void KernelRuntime::RunOpAssignMemory(const std::vector<tensor::TensorPtr> &input_tensors,
|
|
|
|
|
void KernelRuntime::RunOpAssignMemory(const ValuePtr &pre_output_value,
|
|
|
|
|
const std::vector<tensor::TensorPtr> &input_tensors,
|
|
|
|
|
session::KernelGraph *graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
RunOpAssignInputMemory(input_tensors, graph);
|
|
|
|
|
AssignStaticMemoryValueNode(graph);
|
|
|
|
|
RunOpAssignOutputNodeMemory(pre_output_value, graph);
|
|
|
|
|
for (const auto &cnode : graph->execution_order()) {
|
|
|
|
|
RunOpAssignOutputMemory(cnode);
|
|
|
|
|
RunOpAssignWorkSpaceMemory(cnode);
|
|
|
|
@ -322,6 +325,45 @@ void KernelRuntime::RunOpAssignWorkSpaceMemory(const AnfNodePtr &kernel) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void KernelRuntime::RunOpAssignOutputNodeMemory(const ValuePtr &pre_output_value, session::KernelGraph *graph) {
|
|
|
|
|
if (pre_output_value == nullptr) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
std::vector<tensor::TensorPtr> pre_output_tensors;
|
|
|
|
|
TensorValueToTensor(pre_output_value, &pre_output_tensors);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
auto output_nodes = graph->outputs();
|
|
|
|
|
if (pre_output_tensors.size() != output_nodes.size()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The size of pre output tensors [" << pre_output_tensors.size()
|
|
|
|
|
<< "] is not equal to the size of output nodes of graph [" << output_nodes.size() << "]";
|
|
|
|
|
}
|
|
|
|
|
// share output address with pre output tensors
|
|
|
|
|
for (size_t i = 0; i < output_nodes.size(); ++i) {
|
|
|
|
|
auto output_node_with_index = AnfAlgo::VisitKernel(output_nodes[i], 0);
|
|
|
|
|
if (!output_node_with_index.first->isa<CNode>()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The output node should be a cnode , but it is "
|
|
|
|
|
<< output_node_with_index.first->DebugString();
|
|
|
|
|
}
|
|
|
|
|
auto real_output_cnode = output_node_with_index.first->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(real_output_cnode);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(pre_output_tensors[i]);
|
|
|
|
|
if (pre_output_tensors[i]->device_address() == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The address of pre output tensor [" << i << "] is a nullptr!";
|
|
|
|
|
}
|
|
|
|
|
if (opt::IsNopNode(real_output_cnode)) {
|
|
|
|
|
if (real_output_cnode->inputs().size() < 2) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "The input size of output node: " << real_output_cnode->DebugString()
|
|
|
|
|
<< " should large than one!";
|
|
|
|
|
}
|
|
|
|
|
AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(pre_output_tensors[i]->device_address()),
|
|
|
|
|
output_node_with_index.second, real_output_cnode->input(1).get());
|
|
|
|
|
} else {
|
|
|
|
|
AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(pre_output_tensors[i]->device_address()),
|
|
|
|
|
output_node_with_index.second, output_node_with_index.first.get());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(mem_manager_);
|
|
|
|
@ -573,11 +615,18 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const
|
|
|
|
|
MS_EXCEPTION_IF_NULL(mem_manager_);
|
|
|
|
|
auto ms_context = MsContext::GetInstance();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(ms_context);
|
|
|
|
|
auto tensor = node_value->cast<TensorPtr>();
|
|
|
|
|
std::vector<tensor::TensorPtr> tensors;
|
|
|
|
|
TensorValueToTensor(node_value, &tensors);
|
|
|
|
|
for (const auto &tensor : tensors) {
|
|
|
|
|
if (tensor == nullptr) {
|
|
|
|
|
MS_LOG(WARNING) << "Tensor is null";
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
if (tensor->device_address() != nullptr) {
|
|
|
|
|
AnfAlgo::SetOutputAddr(std::dynamic_pointer_cast<device::DeviceAddress>(tensor->device_address()), output_idx++,
|
|
|
|
|
value_node.get());
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
size_t tensor_size = tensor->data().nbytes();
|
|
|
|
|
auto node_size = CountNodeDeviceMemorySize(value_node, output_idx);
|
|
|
|
|
TypeId output_type_id = AnfAlgo::GetOutputDeviceDataType(value_node, output_idx);
|
|
|
|
@ -596,9 +645,10 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const
|
|
|
|
|
AnfAlgo::SetOutputAddr(address, output_idx, value_node.get());
|
|
|
|
|
if (!address->SyncHostToDevice(trans::GetRuntimePaddingShape(value_node, 0), tensor_size, tensor->data_type(),
|
|
|
|
|
tensor->data_c())) {
|
|
|
|
|
MS_EXCEPTION(NotExistsError) << "ValueNode SyncHostToDevice fail!" << value_node->DebugString() << "node format is"
|
|
|
|
|
<< AnfAlgo::GetOutputFormat(value_node, output_idx) << "node dtype is "
|
|
|
|
|
<< AnfAlgo::GetOutputInferDataType(value_node, output_idx);
|
|
|
|
|
MS_EXCEPTION(NotExistsError) << "ValueNode SyncHostToDevice fail!" << value_node->DebugString()
|
|
|
|
|
<< "node format is" << AnfAlgo::GetOutputFormat(value_node, output_idx)
|
|
|
|
|
<< "node dtype is " << AnfAlgo::GetOutputInferDataType(value_node, output_idx);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -615,7 +665,7 @@ void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) {
|
|
|
|
|
}
|
|
|
|
|
auto &node_value = value_node->value();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node_value);
|
|
|
|
|
if (node_value->isa<Tensor>()) {
|
|
|
|
|
if (node_value->isa<Tensor>() || node_value->isa<ValueTuple>()) {
|
|
|
|
|
AssignValueNodeTensor(value_node, node_value, 0);
|
|
|
|
|
} else if (node_value->isa<StringImm>()) {
|
|
|
|
|
auto value = GetValue<std::string>(node_value);
|
|
|
|
|