You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
mindspore/mindspore/ccsrc/session/anf_runtime_algorithm.cc

991 lines
39 KiB

/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "session/anf_runtime_algorithm.h"
#include <memory>
#include <algorithm>
#include <map>
#include <set>
#include "ir/anf.h"
#include "ir/func_graph.h"
#include "operator/ops.h"
#include "utils/utils.h"
#include "device/kernel_info.h"
#include "device/device_address.h"
#include "pre_activate/common/helper.h"
#include "kernel/kernel.h"
#include "kernel/kernel_build_info.h"
#include "common/utils.h"
#include "common/trans.h"
namespace mindspore {
namespace session {
using abstract::AbstractTensor;
using abstract::AbstractTuple;
using device::KernelInfo;
using device::ascend::AscendDeviceAddress;
using kernel::KernelBuildInfoPtr;
using kernel::KernelMod;
using kernel::KernelModPtr;
namespace {
// Returns the dimensions stored in `shape`, each converted with IntToSize, in their original order.
std::vector<size_t> TransShapeToSizet(const abstract::ShapePtr &shape) {
  MS_EXCEPTION_IF_NULL(shape);
  const auto &dims = shape->shape();
  std::vector<size_t> shape_size_t;
  for (const auto &dim : dims) {
    shape_size_t.push_back(IntToSize(dim));
  }
  return shape_size_t;
}
} // namespace
// Follows `anf_node` through make_tuple / tuple_get_item / depend / control_depend wrappers and returns the
// node reached together with an output index.
//   - ValueNode or Parameter: returned as-is with index 0.
//   - make_tuple: continues with input `index + 1` and index 0.
//   - tuple_get_item: continues with its real input, using the int stored in its index value node.
//   - depend / control_depend: continues with input kRealInputIndexInDepend and index 0.
//   - any other CNode: returned as-is with the incoming `index`.
// Anything that is none of ValueNode, Parameter or CNode is reported through MS_LOG(EXCEPTION).
KernelWithIndex AnfRuntimeAlgorithm::VisitKernel(const AnfNodePtr &anf_node, size_t index) {
  MS_EXCEPTION_IF_NULL(anf_node);
  if (anf_node->isa<ValueNode>() || anf_node->isa<Parameter>()) {
    return std::make_pair(anf_node, 0);
  }
  if (!anf_node->isa<CNode>()) {
    MS_LOG(EXCEPTION) << "The input is invalid";
  }
  auto cnode = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto input0 = cnode->input(0);
  MS_EXCEPTION_IF_NULL(input0);
  if (IsPrimitive(input0, prim::kPrimMakeTuple)) {
    // Input 0 is the primitive, so tuple element `index` sits at input `index + 1`.
    auto node = cnode->input(index + IntToSize(1));
    MS_EXCEPTION_IF_NULL(node);
    return VisitKernel(node, 0);
  }
  if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
    if (cnode->inputs().size() != kTupleGetItemInputSize) {
      MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
    }
    auto input2 = cnode->input(kInputNodeOutputIndexInTupleGetItem);
    MS_EXCEPTION_IF_NULL(input2);
    auto value_node = input2->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    int item_idx = GetValue<int>(value_node->value());
    return VisitKernel(cnode->input(kRealInputNodeIndexInTupleGetItem), IntToSize(item_idx));
  }
  const bool is_depend = IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend);
  if (is_depend) {
    return VisitKernel(cnode->input(kRealInputIndexInDepend), 0);
  }
  return std::make_pair(anf_node, index);
}
// Variant of the kernel walk that stops early on caller-selected primitives.
//   - If `anf_node` matches any primitive in `return_types`, it is returned immediately with `index`.
//   - ValueNode or Parameter: returned with index 0.
//   - tuple_get_item: continues with its real input, using the int stored in its index value node.
//   - depend / control_depend: continues with input kRealInputIndexInDepend and index 0.
//   - nop node (opt::IsNopNode) while `visit_nop_node` is true: continues with input 1; a nop node whose
//     inputs().size() is not 2 is reported through MS_LOG(EXCEPTION).
//   - any other CNode: returned with the incoming `index`.
// Anything that is none of ValueNode, Parameter or CNode is reported through MS_LOG(EXCEPTION).
KernelWithIndex AnfRuntimeAlgorithm::VisitKernelWithReturnType(const AnfNodePtr &anf_node, size_t index,
                                                               bool visit_nop_node,
                                                               const std::vector<PrimitivePtr> &return_types) {
  MS_EXCEPTION_IF_NULL(anf_node);
  const bool hit_return_type =
    std::any_of(return_types.begin(), return_types.end(),
                [&anf_node](const PrimitivePtr &prim_type) { return CheckPrimitiveType(anf_node, prim_type); });
  if (hit_return_type) {
    return std::make_pair(anf_node, index);
  }
  if (anf_node->isa<ValueNode>() || anf_node->isa<Parameter>()) {
    return std::make_pair(anf_node, 0);
  }
  if (!anf_node->isa<CNode>()) {
    MS_LOG(EXCEPTION) << "The input is invalid";
  }
  auto cnode = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto input0 = cnode->input(0);
  MS_EXCEPTION_IF_NULL(input0);
  if (IsPrimitive(input0, prim::kPrimTupleGetItem)) {
    if (cnode->inputs().size() != kTupleGetItemInputSize) {
      MS_LOG(EXCEPTION) << "The node tuple_get_item must have 2 inputs!";
    }
    auto input2 = cnode->input(kInputNodeOutputIndexInTupleGetItem);
    MS_EXCEPTION_IF_NULL(input2);
    auto value_node = input2->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    int item_idx = GetValue<int>(value_node->value());
    return VisitKernelWithReturnType(cnode->input(kRealInputNodeIndexInTupleGetItem), IntToSize(item_idx),
                                     visit_nop_node, return_types);
  }
  if (IsPrimitive(input0, prim::kPrimDepend) || IsPrimitive(input0, prim::kPrimControlDepend)) {
    return VisitKernelWithReturnType(cnode->input(kRealInputIndexInDepend), 0, visit_nop_node, return_types);
  }
  if (opt::IsNopNode(cnode) && visit_nop_node) {
    if (cnode->inputs().size() != 2) {
      MS_LOG(EXCEPTION) << cnode->DebugString() << "Invalid nop node";
    }
    return VisitKernelWithReturnType(cnode->input(1), 0, visit_nop_node, return_types);
  }
  return std::make_pair(anf_node, index);
}
std::vector<AnfNodePtr> AnfRuntimeAlgorithm::GetAllOutput(const AnfNodePtr &node,
const std::vector<PrimitivePtr> &return_types) {
std::vector<AnfNodePtr> ret;
auto return_prim_type = return_types;
// if visited make_tuple should return back
return_prim_type.push_back(prim::kPrimMakeTuple);
auto item_with_index = AnfAlgo::VisitKernelWithReturnType(node, 0, false, return_prim_type);
if (AnfAlgo::CheckPrimitiveType(item_with_index.first, prim::kPrimMakeTuple)) {
MS_EXCEPTION_IF_NULL(item_with_index.first);
auto make_tuple = item_with_index.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(make_tuple);
for (size_t i = 1; i < make_tuple->inputs().size(); i++) {
auto input_i_vector = GetAllOutput(make_tuple->input(i), return_types);
(void)std::copy(input_i_vector.begin(), input_i_vector.end(), std::back_inserter(ret));
}
return ret;
}
ret.push_back(item_with_index.first);
return ret;
}
// Returns the input of `node` located at kAnfPrimitiveIndex.
AnfNodePtr AnfRuntimeAlgorithm::GetCNodePrimitiveNode(const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  AnfNodePtr primitive_input = node->input(kAnfPrimitiveIndex);
  return primitive_input;
}
// Returns the value held by the primitive input of `node`, cast to PrimitivePtr.
// `node` is cast to CNodePtr and its primitive input to ValueNodePtr; each intermediate pointer is checked
// with MS_EXCEPTION_IF_NULL. The final cast result itself is returned unchecked.
PrimitivePtr AnfRuntimeAlgorithm::GetCNodePrimitive(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto attr_input = GetCNodePrimitiveNode(cnode);
  MS_EXCEPTION_IF_NULL(attr_input);
  auto value_node = attr_input->cast<ValueNodePtr>();
  MS_EXCEPTION_IF_NULL(value_node);
  auto value = value_node->value();
  MS_EXCEPTION_IF_NULL(value);
  return value->cast<PrimitivePtr>();
}
// Returns true when `node` is a CNode whose input at kAnfPrimitiveIndex satisfies IsPrimitive(..., primitive_type).
// Nodes that are not CNodes always yield false.
bool AnfRuntimeAlgorithm::CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type) {
  MS_EXCEPTION_IF_NULL(node);
  if (node->isa<CNode>()) {
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    return IsPrimitive(cnode->input(kAnfPrimitiveIndex), primitive_type);
  }
  return false;
}
// Returns the name of the primitive attached to `node`.
// Only CNodes are accepted; any other node kind is reported through MS_LOG(EXCEPTION).
std::string AnfRuntimeAlgorithm::GetCNodeName(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    MS_LOG(EXCEPTION) << "Unknown anf node type " << node->DebugString();
  }
  auto primitive = AnfAlgo::GetCNodePrimitive(node);
  MS_EXCEPTION_IF_NULL(primitive);
  return primitive->name();
}
// Returns node->DebugString() for a non-null `node`.
std::string AnfRuntimeAlgorithm::GetNodeDebugString(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  std::string debug_text = node->DebugString();
  return debug_text;
}
// Stores attribute `key` = `value` on the primitive of `node`.
// `node` must be a CNode; other node kinds are reported through MS_LOG(EXCEPTION).
void AnfRuntimeAlgorithm::SetNodeAttr(const std::string &key, const ValuePtr &value, const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  const bool is_cnode = node->isa<CNode>();
  if (!is_cnode) {
    MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node->DebugString();
  }
  // The attribute lives on the primitive object, not on the node itself.
  auto primitive = AnfAlgo::GetCNodePrimitive(node);
  MS_EXCEPTION_IF_NULL(primitive);
  primitive->set_attr(key, value);
}
// Copies attribute `key` from `from` to `to` under the same name; forwards to the old_key/new_key overload.
void AnfRuntimeAlgorithm::CopyNodeAttr(const std::string &key, const AnfNodePtr &from, const AnfNodePtr &to) {
  const std::string &old_key = key;
  const std::string &new_key = key;
  CopyNodeAttr(old_key, new_key, from, to);
}
// Reads attribute `old_key` from the primitive of `from` and stores it as `new_key` on the primitive of `to`.
// Both nodes must be CNodes; otherwise the mismatch is reported through MS_LOG(EXCEPTION).
void AnfRuntimeAlgorithm::CopyNodeAttr(const std::string &old_key, const std::string &new_key, const AnfNodePtr &from,
                                       const AnfNodePtr &to) {
  MS_EXCEPTION_IF_NULL(from);
  MS_EXCEPTION_IF_NULL(to);
  const bool both_cnodes = from->isa<CNode>() && to->isa<CNode>();
  if (!both_cnodes) {
    MS_LOG(EXCEPTION) << "Only cnode has attr, but this from_anf is " << from->DebugString() << " ,to_node is "
                      << to->DebugString();
  }
  auto from_primitive = AnfAlgo::GetCNodePrimitive(from);
  MS_EXCEPTION_IF_NULL(from_primitive);
  auto to_primitive = AnfAlgo::GetCNodePrimitive(to);
  MS_EXCEPTION_IF_NULL(to_primitive);
  auto copied_attr = from_primitive->GetAttr(old_key);
  to_primitive->set_attr(new_key, copied_attr);
}
// Copies every attribute of the primitive of `from` onto the primitive of `to` via SetAttrs.
// Both nodes must be CNodes; otherwise the mismatch is reported through MS_LOG(EXCEPTION).
void AnfRuntimeAlgorithm::CopyNodeAttrs(const AnfNodePtr &from, const AnfNodePtr &to) {
  MS_EXCEPTION_IF_NULL(from);
  MS_EXCEPTION_IF_NULL(to);
  if (!from->isa<CNode>() || !to->isa<CNode>()) {
    // The second DebugString must describe `to`; the message previously printed `from` twice.
    MS_LOG(EXCEPTION) << "Only cnode has attr, but this from_anf is " << from->DebugString() << ",to_node is "
                      << to->DebugString();
  }
  auto from_primitive = AnfAlgo::GetCNodePrimitive(from);
  MS_EXCEPTION_IF_NULL(from_primitive);
  auto to_primitive = AnfAlgo::GetCNodePrimitive(to);
  MS_EXCEPTION_IF_NULL(to_primitive);
  (void)to_primitive->SetAttrs(from_primitive->attrs());
}
// Removes attribute `key` from the primitive of `node` via EraseAttr.
// `node` must be a CNode; other node kinds are reported through MS_LOG(EXCEPTION).
void AnfRuntimeAlgorithm::EraseNodeAttr(const std::string &key, const AnfNodePtr node) {
  MS_EXCEPTION_IF_NULL(node);
  const bool is_cnode = node->isa<CNode>();
  if (!is_cnode) {
    MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node->DebugString();
  }
  auto primitive = AnfAlgo::GetCNodePrimitive(node);
  MS_EXCEPTION_IF_NULL(primitive);
  primitive->EraseAttr(key);
}
// Returns whether the primitive of `node` reports attribute `key` through HasAttr.
bool AnfRuntimeAlgorithm::HasNodeAttr(const std::string &key, const CNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  auto primitive = AnfAlgo::GetCNodePrimitive(node);
  MS_EXCEPTION_IF_NULL(primitive);
  const bool has_attr = primitive->HasAttr(key);
  return has_attr;
}
// Returns the number of inputs of `node` excluding inputs[0], i.e. inputs().size() - 1.
// `node` must be a CNode with a non-empty input list; violations are reported through MS_LOG(EXCEPTION).
size_t AnfRuntimeAlgorithm::GetInputTensorNum(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    MS_LOG(EXCEPTION) << "Only cnode has real input, but this anf is " << node->DebugString();
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  const size_t total_inputs = cnode->inputs().size();
  if (total_inputs == 0) {
    MS_LOG(EXCEPTION) << "cnode inputs size can't be zero";
  }
  // inputs[0] is the value node storing the attrs; the remaining inputs are the real ones.
  return total_inputs - 1;
}
// Returns the number of outputs implied by node->Type():
//   - null type or TypeNone: 0
//   - Tuple: the tuple's size()
//   - every other type (including TensorType and Number): 1
size_t AnfRuntimeAlgorithm::GetOutputTensorNum(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  TypePtr type = node->Type();
  if (type == nullptr) {
    return 0;
  }
  if (type->isa<Tuple>()) {
    auto tuple_type = type->cast<TuplePtr>();
    MS_EXCEPTION_IF_NULL(tuple_type);
    return tuple_type->size();
  }
  // Only TypeNone maps to zero outputs; tensors, numbers and anything else count as one.
  return type->isa<TypeNone>() ? 0 : 1;
}
// Returns the format of output `output_idx` of `node`.
// When IsRealKernel(node) holds, the format comes from the node's selected kernel build info and a value of
// kInvalidFormat is reported through MS_LOG(EXCEPTION). Otherwise the query is forwarded to
// GetPrevNodeOutputFormat(node, output_idx).
// NOTE(review): the range check uses `>`, so output_idx == GetOutputTensorNum(node) is accepted — confirm this
// is intended.
std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  if (output_idx > GetOutputTensorNum(node)) {
    MS_LOG(EXCEPTION) << "Output index:" << output_idx
                      << " is out of the node output range :" << GetOutputTensorNum(node) << " #node ["
                      << node->DebugString() << "]";
  }
  if (AnfAlgo::IsRealKernel(node)) {
    auto kernel_info = node->kernel_info();
    MS_EXCEPTION_IF_NULL(kernel_info);
    auto build_info = kernel_info->select_kernel_build_info();
    MS_EXCEPTION_IF_NULL(build_info);
    auto format = build_info->GetOutputFormat(output_idx);
    if (format == kernel::KernelBuildInfo::kInvalidFormat) {
      MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
                        << " has a invalid output format";
    }
    return format;
  }
  return AnfAlgo::GetPrevNodeOutputFormat(node, output_idx);
}
// Returns the format of input `input_idx` of `node`.
// When IsRealKernel(node) does not hold, the answer is the output format of the node feeding that input
// (GetPrevNodeOutputFormat). Otherwise the format comes from the node's selected kernel build info, and a value
// of kInvalidFormat is reported through MS_LOG(EXCEPTION).
std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t input_idx) {
  MS_EXCEPTION_IF_NULL(node);
  if (input_idx > GetInputTensorNum(node)) {
    MS_LOG(EXCEPTION) << "Input index :" << input_idx
                      << " is out of the number node Input range :" << GetInputTensorNum(node) << "#node ["
                      << node->DebugString() << "]";
  }
  if (!IsRealKernel(node)) {
    // The result must be returned: previously it was computed and discarded, and execution fell through to the
    // kernel-info lookup below, which is meant for real kernels only.
    return GetPrevNodeOutputFormat(node, input_idx);
  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  auto format = build_info->GetInputFormat(input_idx);
  if (format == kernel::KernelBuildInfo::kInvalidFormat) {
    MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
                      << " has a invalid input format";
  }
  return format;
}
// Returns the kernel (with output index) that feeds input `input_idx` of `anf_node`.
// `anf_node` must be a CNode and `input_idx + 1` must be a valid position in its input list; the selected
// input is then resolved with VisitKernel(node, 0). Violations are reported through MS_LOG(EXCEPTION).
KernelWithIndex AnfRuntimeAlgorithm::GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx) {
  MS_EXCEPTION_IF_NULL(anf_node);
  if (!anf_node->isa<CNode>()) {
    MS_LOG(EXCEPTION) << anf_node->DebugString() << "anf_node is not CNode.";
  }
  auto cnode = anf_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  // Input 0 is skipped, so real input `input_idx` is stored one slot further.
  const size_t input_pos = input_idx + 1;
  if (input_pos >= cnode->inputs().size()) {
    MS_LOG(EXCEPTION) << "Input index " << input_idx << " is larger than input number " << GetInputTensorNum(cnode);
  }
  auto node = cnode->input(input_pos);
  MS_EXCEPTION_IF_NULL(node);
  return VisitKernel(node, 0);
}
std::string AnfRuntimeAlgorithm::GetPrevNodeOutputFormat(const AnfNodePtr &anf_node, size_t input_idx) {
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
return AnfRuntimeAlgorithm::GetOutputFormat(kernel_with_index.first, kernel_with_index.second);
}
std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetPrevNodeOutputReshapeType(const AnfNodePtr &node, size_t input_idx) {
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
return GetOutputReshapeType(kernel_with_index.first, kernel_with_index.second);
}
// Returns the inferred shape of output `output_idx` of `node` as a vector of size_t.
//   - node->Shape() is an abstract::Shape and output_idx == 0: that shape.
//   - TupleShape: the element at `output_idx`; a Shape element is converted, a NoShape element yields an empty
//     vector, anything else is reported through MS_LOG(EXCEPTION). An out-of-range index is reported as well.
//   - NoShape: an empty vector.
//   - any other case (including a plain Shape with output_idx != 0): reported through MS_LOG(EXCEPTION).
std::vector<size_t> AnfRuntimeAlgorithm::GetOutputInferShape(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  abstract::BaseShapePtr base_shape = node->Shape();
  MS_EXCEPTION_IF_NULL(base_shape);
  if (base_shape->isa<abstract::Shape>() && output_idx == 0) {
    return TransShapeToSizet(base_shape->cast<abstract::ShapePtr>());
  } else if (base_shape->isa<abstract::TupleShape>()) {
    auto tuple_shape = base_shape->cast<abstract::TupleShapePtr>();
    MS_EXCEPTION_IF_NULL(tuple_shape);
    if (output_idx >= tuple_shape->size()) {
      MS_LOG(EXCEPTION) << "Output index " << output_idx << "is larger than output number " << tuple_shape->size()
                        << ".";
    }
    auto b_shp = (*tuple_shape)[output_idx];
    // Guard the element before dereferencing it, matching the element check in GetOutputInferDataType.
    MS_EXCEPTION_IF_NULL(b_shp);
    if (b_shp->isa<abstract::Shape>()) {
      return TransShapeToSizet(b_shp->cast<abstract::ShapePtr>());
    } else if (b_shp->isa<abstract::NoShape>()) {
      return std::vector<size_t>();
    } else {
      MS_LOG(EXCEPTION) << "The output type of ApplyKernel index:" << output_idx
                        << " should be a NoShape , ArrayShape or a TupleShape, but it is " << base_shape->ToString();
    }
  } else if (base_shape->isa<abstract::NoShape>()) {
    return std::vector<size_t>();
  }
  MS_LOG(EXCEPTION) << "The output type of ApplyKernel should be a NoShape , ArrayShape or a TupleShape, but it is "
                    << base_shape->ToString();
}
std::vector<size_t> AnfRuntimeAlgorithm::GetPrevNodeOutputInferShape(const AnfNodePtr &node, size_t input_idx) {
KernelWithIndex kernel_with_index = AnfAlgo::GetPrevNodeOutput(node, input_idx);
return AnfRuntimeAlgorithm::GetOutputInferShape(kernel_with_index.first, kernel_with_index.second);
}
// Returns the device-side shape of output `output_idx` of `node`.
// An empty inferred shape is returned unchanged. Otherwise the inferred shape is padded with
// trans::PaddingShapeTo4d (using the output reshape type) when trans::IsNeedPadding says so, and then converted
// with trans::TransShapeToDevice for the output format.
std::vector<size_t> AnfRuntimeAlgorithm::GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx) {
  auto format = GetOutputFormat(node, output_idx);
  auto infer_shape = GetOutputInferShape(node, output_idx);
  if (!infer_shape.empty()) {
    const bool need_padding = trans::IsNeedPadding(format, infer_shape.size());
    if (need_padding) {
      infer_shape = trans::PaddingShapeTo4d(infer_shape, GetOutputReshapeType(node, output_idx));
    }
    return trans::TransShapeToDevice(infer_shape, format);
  }
  return infer_shape;
}
// Returns the device-side shape of input `input_idx` of `node`.
// The inferred shape is taken from the producer of that input. An empty inferred shape is returned unchanged.
// Otherwise it is padded with trans::PaddingShapeTo4d (using the input reshape type) when trans::IsNeedPadding
// says so, and then converted with trans::TransShapeToDevice for the input format.
std::vector<size_t> AnfRuntimeAlgorithm::GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx) {
  auto format = GetInputFormat(node, input_idx);
  auto infer_shape = GetPrevNodeOutputInferShape(node, input_idx);
  if (!infer_shape.empty()) {
    const bool need_padding = trans::IsNeedPadding(format, infer_shape.size());
    if (need_padding) {
      infer_shape = trans::PaddingShapeTo4d(infer_shape, GetInputReshapeType(node, input_idx));
    }
    return trans::TransShapeToDevice(infer_shape, format);
  }
  return infer_shape;
}
// Returns the reshape type recorded for input `input_idx` of `node`.
// When IsRealKernel(node) holds, the value comes from the selected kernel build info, or is empty if the build
// info reports IsInputDefaultPadding(). Otherwise the query is forwarded to the producer of that input.
std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNodePtr &node, size_t input_idx) {
  MS_EXCEPTION_IF_NULL(node);
  if (input_idx > GetInputTensorNum(node)) {
    MS_LOG(EXCEPTION) << "The index:" << input_idx
                      << " is out of range of the node's input size : " << GetInputTensorNum(node) << "#node["
                      << node->DebugString() << "]";
  }
  if (IsRealKernel(node)) {
    auto kernel_info = node->kernel_info();
    MS_EXCEPTION_IF_NULL(kernel_info);
    auto build_info = kernel_info->select_kernel_build_info();
    MS_EXCEPTION_IF_NULL(build_info);
    if (!build_info->IsInputDefaultPadding()) {
      return build_info->GetInputReshapeType(input_idx);
    }
    return {};
  }
  return GetPrevNodeOutputReshapeType(node, input_idx);
}
// Returns the reshape type recorded for output `output_idx` of `node`.
// When IsRealKernel(node) holds, the value comes from the selected kernel build info, or is empty if the build
// info reports IsOutputDefaultPadding(). Otherwise the query is forwarded to
// GetPrevNodeOutputReshapeType(node, output_idx).
std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  if (output_idx > GetOutputTensorNum(node)) {
    MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
                      << GetOutputTensorNum(node) << "#node[ " << node->DebugString() << "]";
  }
  if (IsRealKernel(node)) {
    auto kernel_info = node->kernel_info();
    MS_EXCEPTION_IF_NULL(kernel_info);
    auto build_info = kernel_info->select_kernel_build_info();
    MS_EXCEPTION_IF_NULL(build_info);
    if (!build_info->IsOutputDefaultPadding()) {
      return build_info->GetOutputReshapeType(output_idx);
    }
    return {};
  }
  return GetPrevNodeOutputReshapeType(node, output_idx);
}
// Returns the inferred type id of output `output_idx` of `node`, based on node->Type():
//   - TensorType with output_idx == 0: the type id of the tensor's element type.
//   - Tuple: the element at `output_idx` (an out-of-range index is reported through MS_LOG(EXCEPTION)); a
//     TensorType element yields its element type id, any other element yields its own type id, with a WARNING
//     logged when it is not a Number.
//   - everything else: the type id of the type itself.
TypeId AnfRuntimeAlgorithm::GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  TypePtr type_ptr = node->Type();
  MS_EXCEPTION_IF_NULL(type_ptr);
  if (type_ptr->isa<TensorType>() && output_idx == 0) {
    auto tensor_ptr = type_ptr->cast<TensorTypePtr>();
    MS_EXCEPTION_IF_NULL(tensor_ptr);
    TypePtr elem = tensor_ptr->element();
    MS_EXCEPTION_IF_NULL(elem);
    return elem->type_id();
  }
  if (!type_ptr->isa<Tuple>()) {
    // Number types and all remaining types report their own id.
    return type_ptr->type_id();
  }
  auto tuple_ptr = type_ptr->cast<TuplePtr>();
  MS_EXCEPTION_IF_NULL(tuple_ptr);
  if (output_idx >= tuple_ptr->size()) {
    MS_LOG(EXCEPTION) << "Output index " << output_idx << " must be less than output number " << tuple_ptr->size();
  }
  auto tuple_i = (*tuple_ptr)[output_idx];
  MS_EXCEPTION_IF_NULL(tuple_i);
  if (tuple_i->isa<TensorType>()) {
    auto tensor_ptr = tuple_i->cast<TensorTypePtr>();
    MS_EXCEPTION_IF_NULL(tensor_ptr);
    TypePtr elem = tensor_ptr->element();
    MS_EXCEPTION_IF_NULL(elem);
    return elem->type_id();
  }
  if (!tuple_i->isa<Number>()) {
    MS_LOG(WARNING) << "Not support type " << tuple_i->ToString();
  }
  return tuple_i->type_id();
}
// Returns the inferred type id of the kernel output that feeds input `input_idx` of `node`.
TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx) {
  const KernelWithIndex producer = AnfAlgo::GetPrevNodeOutput(node, input_idx);
  const auto &producer_node = producer.first;
  const size_t producer_output_idx = producer.second;
  return AnfRuntimeAlgorithm::GetOutputInferDataType(producer_node, producer_output_idx);
}
// Returns the device data type of output `output_idx` of `node`.
// When IsRealKernel(node) holds, the value comes from the selected kernel build info and kNumberTypeEnd is
// reported through MS_LOG(EXCEPTION). Otherwise the query is forwarded to
// GetPrevNodeOutputDeviceDataType(node, output_idx).
TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  if (output_idx > GetOutputTensorNum(node)) {
    MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
                      << GetOutputTensorNum(node) << "#node [ " << node->DebugString() << "]";
  }
  if (IsRealKernel(node)) {
    auto kernel_info = node->kernel_info();
    MS_EXCEPTION_IF_NULL(kernel_info);
    auto build_info = kernel_info->select_kernel_build_info();
    MS_EXCEPTION_IF_NULL(build_info);
    auto dtype = build_info->GetOutputDeviceType(output_idx);
    if (dtype == TypeId::kNumberTypeEnd) {
      MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
                        << " has a invalid dtype";
    }
    return dtype;
  }
  return GetPrevNodeOutputDeviceDataType(node, output_idx);
}
// Returns the device data type of input `input_idx` of `node`.
// When IsRealKernel(node) holds, the value comes from the selected kernel build info and kNumberTypeEnd is
// reported through MS_LOG(EXCEPTION). Otherwise the device data type of the producer of input 0 is returned.
// NOTE(review): the fallback always queries input 0 rather than `input_idx`, unlike GetInputReshapeType —
// confirm whether this is intentional.
TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_t input_idx) {
  MS_EXCEPTION_IF_NULL(node);
  if (input_idx > GetInputTensorNum(node)) {
    MS_LOG(EXCEPTION) << "The index [" << input_idx << "] is out of range of the node's input size [ "
                      << GetInputTensorNum(node) << "#node [ " << node->DebugString() << "]";
  }
  if (IsRealKernel(node)) {
    auto kernel_info = node->kernel_info();
    MS_EXCEPTION_IF_NULL(kernel_info);
    auto build_info = kernel_info->select_kernel_build_info();
    MS_EXCEPTION_IF_NULL(build_info);
    auto dtype = build_info->GetInputDeviceType(input_idx);
    if (dtype == TypeId::kNumberTypeEnd) {
      MS_LOG(EXCEPTION) << "Node [" << node->DebugString() << "]"
                        << " has a invalid dtype";
    }
    return dtype;
  }
  return GetPrevNodeOutputDeviceDataType(node, 0);
}
// Returns the device data type of the kernel output that feeds input `input_idx` of `anf_node`.
TypeId AnfRuntimeAlgorithm::GetPrevNodeOutputDeviceDataType(const AnfNodePtr &anf_node, size_t input_idx) {
  const KernelWithIndex producer = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
  const auto &producer_node = producer.first;
  const size_t producer_output_idx = producer.second;
  return AnfRuntimeAlgorithm::GetOutputDeviceDataType(producer_node, producer_output_idx);
}
// Returns the (read-only) device address of output `output_idx` of `node`.
// For a nop node (opt::IsNopNode) the address of the producer of its input 0 is returned instead; a nop node
// whose inputs().size() is not 2 is reported through MS_LOG(EXCEPTION). A null address stored in the kernel
// info is reported through MS_LOG(EXCEPTION) as well.
const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  if (opt::IsNopNode(node)) {
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    if (cnode->inputs().size() != 2) {
      MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node";
    }
    return AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(cnode, 0);
  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto addr = kernel_info->GetOutputAddr(output_idx);
  const bool addr_missing = (addr == nullptr);
  if (addr_missing) {
    MS_LOG(EXCEPTION) << "Output_idx " << output_idx << " of node " << node->DebugString()
                      << " output addr is not exist";
  }
  return addr;
}
// Returns the mutable device address (DeviceAddressPtr) of output `output_idx` of `node`.
// For a nop node (opt::IsNopNode) the mutable address of the producer of its input 0 is returned instead; a nop
// node whose inputs().size() is not 2 is reported through MS_LOG(EXCEPTION). A null address stored in the
// kernel info is reported through MS_LOG(EXCEPTION) as well.
DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  if (opt::IsNopNode(node)) {
    auto cnode = node->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(cnode);
    if (cnode->inputs().size() != 2) {
      MS_LOG(EXCEPTION) << node->DebugString() << "Invalid nop node.";
    }
    return AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(cnode, 0);
  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto addr = kernel_info->GetMutableOutputAddr(output_idx);
  const bool addr_missing = (addr == nullptr);
  if (addr_missing) {
    MS_LOG(EXCEPTION) << "Output_idx" << output_idx << " of node " << node->DebugString()
                      << " output addr is not exist";
  }
  return addr;
}
// Reports whether the kernel info of `node` has a device address for output `output_idx`
// (as answered by KernelInfo::OutputAddrExist). An index above GetOutputTensorNum(node) is reported through
// MS_LOG(EXCEPTION).
bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  if (output_idx > GetOutputTensorNum(node)) {
    MS_LOG(EXCEPTION) << "The index [" << output_idx << "] is out of range of the node's output size [ "
                      << GetOutputTensorNum(node) << "#node:[ " << node->DebugString() << "]";
  }
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  const bool addr_exists = kernel_info->OutputAddrExist(output_idx);
  return addr_exists;
}
// Returns the (read-only) device address of the kernel output that feeds input `input_idx` of `anf_node`.
const DeviceAddress *AnfRuntimeAlgorithm::GetPrevNodeOutputAddr(const AnfNodePtr &anf_node, size_t input_idx) {
  const KernelWithIndex producer = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
  const auto &producer_node = producer.first;
  const size_t producer_output_idx = producer.second;
  return AnfRuntimeAlgorithm::GetOutputAddr(producer_node, producer_output_idx);
}
// Returns the mutable device address of the kernel output that feeds input `input_idx` of `anf_node`.
DeviceAddressPtr AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx) {
  const KernelWithIndex producer = AnfAlgo::GetPrevNodeOutput(anf_node, input_idx);
  const auto &producer_node = producer.first;
  const size_t producer_output_idx = producer.second;
  return AnfRuntimeAlgorithm::GetMutableOutputAddr(producer_node, producer_output_idx);
}
// Stores `addr` as the device address of output `output_idx` in the kernel info of `node`.
// A false result from KernelInfo::SetOutputAddr is reported through MS_LOG(EXCEPTION).
void AnfRuntimeAlgorithm::SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) {
  MS_EXCEPTION_IF_NULL(node);
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  const bool stored = kernel_info->SetOutputAddr(addr, output_idx);
  if (stored) {
    return;
  }
  MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set adr" << output_idx << " fail";
}
// Stores `addr` as the workspace device address at position `output_idx` in the kernel info of `node`.
// A false result from KernelInfo::SetWorkspaceAddr is reported through MS_LOG(EXCEPTION).
void AnfRuntimeAlgorithm::SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node) {
  MS_EXCEPTION_IF_NULL(node);
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  const bool stored = kernel_info->SetWorkspaceAddr(addr, output_idx);
  if (stored) {
    return;
  }
  MS_LOG(EXCEPTION) << "Node " << node->DebugString() << "set adr" << output_idx << " fail";
}
// Returns the workspace device address at position `output_idx` from the kernel info of `node`.
// A null address is reported through MS_LOG(EXCEPTION).
DeviceAddress *AnfRuntimeAlgorithm::GetWorkspaceAddr(const AnfNodePtr &node, size_t output_idx) {
  MS_EXCEPTION_IF_NULL(node);
  auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  auto addr = kernel_info->GetWorkspaceAddr(output_idx);
  const bool addr_missing = (addr == nullptr);
  if (addr_missing) {
    MS_LOG(EXCEPTION) << "Output_idx " << output_idx << " of node " << node->DebugString()
                      << "] workspace addr is not exist";
  }
  return addr;
}
// Sets the abstract of `node` from parallel lists of output types and shapes.
// `types` and `shapes` must have the same, non-zero length (violations are reported through MS_LOG(EXCEPTION)).
// A single entry produces an AbstractTensor; several entries produce an AbstractTuple of AbstractTensors.
// Shape dimensions are converted from size_t with SizeToInt.
void AnfRuntimeAlgorithm::SetOutputInferTypeAndShape(const std::vector<TypeId> &types,
                                                     const std::vector<std::vector<size_t>> &shapes, AnfNode *node) {
  MS_EXCEPTION_IF_NULL(node);
  if (types.size() != shapes.size()) {
    MS_LOG(EXCEPTION) << "Types size " << types.size() << "should be same with shapes size " << shapes.size();
  }
  if (shapes.empty()) {
    MS_LOG(EXCEPTION) << "Illegal empty output_types_shapes";
  }
  // Shared conversion of one size_t shape into the int shape AbstractTensor is built from.
  const auto to_int_shape = [](const std::vector<size_t> &shape) {
    std::vector<int> converted;
    std::transform(shape.begin(), shape.end(), std::back_inserter(converted), SizeToInt);
    return converted;
  };
  if (shapes.size() == 1) {
    // Single output: the node's abstract is the tensor itself.
    std::vector<int> shape_int = to_int_shape(shapes[0]);
    auto abstract = std::make_shared<AbstractTensor>(TypeIdToType(types[0]), shape_int);
    node->set_abstract(abstract);
    return;
  }
  // Multiple outputs: wrap one tensor abstract per output in a tuple.
  std::vector<AbstractBasePtr> abstract_list;
  for (size_t i = 0; i < types.size(); ++i) {
    std::vector<int> shape_int = to_int_shape(shapes[i]);
    abstract_list.push_back(std::make_shared<AbstractTensor>(TypeIdToType(types[i]), shape_int));
  }
  auto abstract_tuple = std::make_shared<AbstractTuple>(abstract_list);
  node->set_abstract(abstract_tuple);
}
// copy an abstract of a node to another node
void AnfRuntimeAlgorithm::CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_node) {
to_node->set_abstract(from_node->abstract());
}
// Returns op_pattern() of the kernel build info selected for `node`.
kernel::OpPattern AnfRuntimeAlgorithm::GetOpPattern(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  const auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  const auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  const kernel::OpPattern pattern = build_info->op_pattern();
  return pattern;
}
// Returns kernel_type() of the kernel build info selected for `node`.
KernelType AnfRuntimeAlgorithm::GetKernelType(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  const auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  const auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  const KernelType type = build_info->kernel_type();
  return type;
}
// Returns processor() of the kernel build info selected for `node`.
kernel::Processor AnfRuntimeAlgorithm::GetProcessor(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  const auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  const auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  const kernel::Processor processor = build_info->processor();
  return processor;
}
// Returns fusion_type() of the kernel build info selected for `node`.
kernel::FusionType AnfRuntimeAlgorithm::GetFusionType(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  const auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  const auto build_info = kernel_info->select_kernel_build_info();
  MS_EXCEPTION_IF_NULL(build_info);
  const kernel::FusionType fusion = build_info->fusion_type();
  return fusion;
}
// Records `select_kernel_build_info` as the selected kernel build info in the kernel info of `node`.
void AnfRuntimeAlgorithm::SetSelectKernelBuildInfo(const KernelBuildInfoPtr &select_kernel_build_info, AnfNode *node) {
  MS_EXCEPTION_IF_NULL(node);
  const auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  kernel_info->set_select_kernel_build_info(select_kernel_build_info);
}
// Returns the result of GetMutableSelectKernelBuildInfo() on the kernel info of `node`.
KernelBuildInfoPtr AnfRuntimeAlgorithm::GetSelectKernelBuildInfo(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  const auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  KernelBuildInfoPtr selected_build_info = kernel_info->GetMutableSelectKernelBuildInfo();
  return selected_build_info;
}
// Returns the result of MutableKernelMod() on the kernel info of `node`.
KernelMod *AnfRuntimeAlgorithm::GetKernelMod(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  const auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  KernelMod *kernel_mod = kernel_info->MutableKernelMod();
  return kernel_mod;
}
// Stores `kernel_mod` in the kernel info of `node` via set_kernel_mod.
void AnfRuntimeAlgorithm::SetKernelMod(const KernelModPtr &kernel_mod, AnfNode *node) {
  MS_EXCEPTION_IF_NULL(node);
  const auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  kernel_info->set_kernel_mod(kernel_mod);
}
// Returns whether `node` counts as a real kernel.
//   - Any node that is not a CNode (e.g. a parameter or value node) is reported as real.
//   - A CNode is real unless its first input is one of the virtual primitives listed below: the summary ops,
//     make_tuple, state_set_item, depend, tuple_get_item, control_depend and return.
// A CNode with an empty input list is reported through MS_LOG(EXCEPTION).
bool AnfRuntimeAlgorithm::IsRealKernel(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  // Parameters and value nodes are treated as real kernels here.
  if (!node->isa<CNode>()) {
    return true;
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  if (cnode->inputs().empty()) {
    // The message is streamed, so the printf-style "%s" placeholder was printed literally; wrap the debug
    // string in the parentheses instead.
    MS_LOG(EXCEPTION) << "Illegal null input of cnode(" << node->DebugString() << ")";
  }
  auto input = cnode->inputs()[0];
  bool is_virtual_node = IsPrimitive(input, prim::kPrimImageSummary) || IsPrimitive(input, prim::kPrimScalarSummary) ||
                         IsPrimitive(input, prim::kPrimTensorSummary) ||
                         IsPrimitive(input, prim::kPrimHistogramSummary) || IsPrimitive(input, prim::kPrimMakeTuple) ||
                         IsPrimitive(input, prim::kPrimStateSetItem) || IsPrimitive(input, prim::kPrimDepend) ||
                         IsPrimitive(input, prim::kPrimTupleGetItem) || IsPrimitive(input, prim::kPrimControlDepend) ||
                         IsPrimitive(input, prim::kPrimReturn);
  return !is_virtual_node;
}
// Returns whether `node` is a CNode that counts as a real kernel.
// Non-CNodes (parameters, value nodes) yield false; a return node yields true; every other CNode is decided
// by IsRealKernel.
bool AnfRuntimeAlgorithm::IsRealCNodeKernel(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    return false;
  }
  // A return node is accepted here before IsRealKernel is consulted.
  return CheckPrimitiveType(node, prim::kPrimReturn) || IsRealKernel(node);
}
// Returns node->has_default() for a non-null parameter node.
bool AnfRuntimeAlgorithm::IsParameterWeight(const ParameterPtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  const bool has_default_value = node->has_default();
  return has_default_value;
}
// Stores `stream_id` in the kernel info of `node`.
void AnfRuntimeAlgorithm::SetStreamId(uint32_t stream_id, AnfNode *node) {
  MS_EXCEPTION_IF_NULL(node);
  const auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  kernel_info->set_stream_id(stream_id);
}
// Returns the stream id stored in the kernel info of `node`.
uint32_t AnfRuntimeAlgorithm::GetStreamId(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  const auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  const uint32_t stream_id = kernel_info->stream_id();
  return stream_id;
}
// Stores `stream_label` as the stream distinction label in the kernel info of `node`.
void AnfRuntimeAlgorithm::SetStreamDistinctionLabel(uint32_t stream_label, AnfNode *node) {
  MS_EXCEPTION_IF_NULL(node);
  const auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  kernel_info->set_stream_distinction_label(stream_label);
}
// Returns the stream distinction label stored in the kernel info of `node`.
uint32_t AnfRuntimeAlgorithm::GetStreamDistinctionLabel(const AnfNode *node) {
  MS_EXCEPTION_IF_NULL(node);
  const auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  const uint32_t stream_label = kernel_info->stream_distinction_label();
  return stream_label;
}
// Stores `graph_id` in the kernel info of `node`.
void AnfRuntimeAlgorithm::SetGraphId(uint32_t graph_id, AnfNode *node) {
  MS_EXCEPTION_IF_NULL(node);
  const auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  kernel_info->set_graph_id(graph_id);
}
// Returns the graph id stored in the kernel info of `node`.
uint32_t AnfRuntimeAlgorithm::GetGraphId(const AnfNode *node) {
  MS_EXCEPTION_IF_NULL(node);
  const auto kernel_info = node->kernel_info();
  MS_EXCEPTION_IF_NULL(kernel_info);
  const uint32_t graph_id = kernel_info->graph_id();
  return graph_id;
}
// Returns whether anf->Type() is a Tuple; both `anf` and its type must be non-null.
bool AnfRuntimeAlgorithm::IsTupleOutput(const AnfNodePtr &anf) {
  MS_EXCEPTION_IF_NULL(anf);
  TypePtr type = anf->Type();
  MS_EXCEPTION_IF_NULL(type);
  const bool is_tuple = type->isa<Tuple>();
  return is_tuple;
}
// Returns real input `index` of `node`, i.e. node->input(index + 1), since input 0 is the primitive node.
// A position that is not below node->inputs().size() is reported through MS_LOG(EXCEPTION).
AnfNodePtr AnfRuntimeAlgorithm::GetInputNode(const CNodePtr &node, size_t index) {
  MS_EXCEPTION_IF_NULL(node);
  auto get_input_index = index + 1;
  // Valid positions are [0, size); the previous `>` comparison let position == size through to input().
  if (get_input_index >= node->inputs().size()) {
    MS_LOG(EXCEPTION) << "Input index size " << get_input_index << "but the node input size just"
                      << node->inputs().size();
  }
  // input 0 is primitive node
  return node->input(get_input_index);
}
bool AnfRuntimeAlgorithm::IsFeatureMapOutput(const AnfNodePtr &node) {
MS_EXCEPTION_IF_NULL(node);
if (node->isa<ValueNode>()) {
return false;
}
auto kernel_info = node->kernel_info();
MS_EXCEPTION_IF_NULL(kernel_info);
return kernel_info->is_feature_map();
}
// Reports whether the `input_index`-th real input of `node` is a feature map.
//
// `node` must be a CNode; any other node kind logs an exception. The input is
// read from position `input_index + 1` because position 0 is skipped, and the
// decision is delegated to IsFeatureMapOutput on that input node.
bool AnfRuntimeAlgorithm::IsFeatureMapInput(const AnfNodePtr &node, size_t input_index) {
  // Guard against a null node before calling isa<>() on it.
  MS_EXCEPTION_IF_NULL(node);
  if (!node->isa<CNode>()) {
    MS_LOG(EXCEPTION) << "Cannot input a parameter or a valuenode to charge it's input if is a feature map";
  }
  auto cnode = node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(cnode);
  auto input_node = cnode->input(input_index + 1);
  return IsFeatureMapOutput(input_node);
}
// Maps the input index `cur_index` of `anf_node` to the index that should actually be used.
//
// The remapping only applies when the node's kernel type is TBE_KERNEL and the
// node name appears in the table below; each table entry maps an original
// input index to its replacement. In every other case, including an index the
// table does not list for that node, `cur_index` is returned unchanged.
size_t AnfRuntimeAlgorithm::GetRealInputIndex(const mindspore::AnfNodePtr &anf_node, const size_t cur_index) {
  MS_EXCEPTION_IF_NULL(anf_node);
  // The table is const and is only queried with find(): operator[] would insert a
  // zero entry for an unknown index, mutating this shared static and returning 0.
  static const std::map<std::string, std::map<size_t, size_t>> spec_node_list = {
    {prim::kPrimConv2DBackpropInput->name(), {{0, 1}, {1, 0}}},
    {kFusionOpConv2DBackpropInputReluGradV2Name, {{0, 1}, {1, 0}, {2, 2}}},
    {kFusionOpConv2DBackpropInputAddNReluGradV2Name, {{0, 1}, {1, 0}, {2, 2}, {3, 3}}},
    {prim::kPrimConv2DBackpropFilter->name(), {{0, 1}, {1, 0}}},
    {prim::kPrimLogSoftmaxGrad->name(), {{0, 1}, {1, 0}}},
    {prim::kPrimLayerNormGrad->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}},
    {prim::kPrimLayerNormBetaGammaBackprop->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}}},
    {prim::kPrimLayerNormXBackprop->name(), {{0, 1}, {1, 0}, {2, 2}, {3, 3}, {4, 4}}},
    {prim::kPrimMinimumGrad->name(), {{0, 2}, {1, 0}, {2, 1}}},
    {prim::kPrimMaximumGrad->name(), {{0, 2}, {1, 0}, {2, 1}}}};
  size_t ret = cur_index;
  auto node_name = AnfAlgo::GetCNodeName(anf_node);
  if (AnfAlgo::GetKernelType(anf_node) == TBE_KERNEL) {
    auto find = spec_node_list.find(node_name);
    if (find != spec_node_list.end()) {
      const auto &index_map = find->second;
      auto index_iter = index_map.find(cur_index);
      if (index_iter != index_map.end()) {
        ret = index_iter->second;
        MS_LOG(INFO) << "Real input index change to" << ret << ", node name:" << node_name;
      }
    }
  }
  return ret;
}
// Replaces the `index`-th real input of `node` with `input_node`.
// The input is written at position `index + 1` of the node's input list.
void AnfRuntimeAlgorithm::SetNodeInput(const CNodePtr &node, const AnfNodePtr &input_node, size_t index) {
  MS_EXCEPTION_IF_NULL(node);
  MS_EXCEPTION_IF_NULL(input_node);
  const size_t real_position = index + 1;
  node->set_input(real_position, input_node);
}
// Reports whether `node` is a CNode whose name is one of the communication ops:
// AllReduce, AllGather, Broadcast or ReduceScatter (compared against the
// corresponding k*OpName constants). Any non-CNode yields false.
bool AnfRuntimeAlgorithm::IsCommunicationOp(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (node->isa<CNode>()) {
    const auto op_name = AnfAlgo::GetCNodeName(node);
    return op_name == kAllReduceOpName || op_name == kAllGatherOpName || op_name == kBroadcastOpName ||
           op_name == kReduceScatterOpName;
  }
  return false;
}
// Reports whether the CNode name of `node` equals kGetNextOpName.
bool AnfRuntimeAlgorithm::IsGetNext(const NotNull<AnfNodePtr> &node) {
  return AnfAlgo::GetCNodeName(node) == kGetNextOpName;
}
// Extracts the FuncGraph held by a value node.
// Returns nullptr when `node` is not a ValueNode or when it holds no value;
// otherwise returns the result of casting the held value to FuncGraphPtr.
FuncGraphPtr AnfRuntimeAlgorithm::GetValueNodeFuncGraph(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  const auto as_value_node = node->cast<ValueNodePtr>();
  if (as_value_node != nullptr) {
    const auto held_value = as_value_node->value();
    if (held_value != nullptr) {
      return held_value->cast<FuncGraphPtr>();
    }
  }
  return nullptr;
}
// Collects the kernel graph(s) a "call" node may invoke.
//
// - If input 1 of the call is a ValueNode, the result holds the single value of
//   that node cast to KernelGraphPtr.
// - If input 1 is a Switch CNode, the result holds two graphs, taken from
//   input 1 of the CNodes found at switch inputs 2 and 3.
// - Otherwise the result is empty.
// An exception is logged when `call_node` is not a "call" primitive.
std::vector<KernelGraphPtr> AnfRuntimeAlgorithm::GetCallNodeKernelGraph(const CNodePtr &call_node) {
  // Check for null first: the error path below dereferences call_node for DebugString().
  MS_EXCEPTION_IF_NULL(call_node);
  if (!AnfAlgo::CheckPrimitiveType(call_node, std::make_shared<Primitive>("call"))) {
    MS_LOG(EXCEPTION) << "anf node: " << call_node->DebugString() << "is not a call node.";
  }
  auto input1 = call_node->input(1);
  MS_EXCEPTION_IF_NULL(input1);
  if (input1->isa<ValueNode>()) {
    auto value_node = input1->cast<ValueNodePtr>();
    MS_EXCEPTION_IF_NULL(value_node);
    auto kernel_graph = value_node->value();
    MS_EXCEPTION_IF_NULL(kernel_graph);
    return {kernel_graph->cast<KernelGraphPtr>()};
  } else if (input1->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input1, prim::kPrimSwitch)) {
    auto switch_node = input1->cast<CNodePtr>();
    MS_EXCEPTION_IF_NULL(switch_node);
    // Resolves one switch input to the kernel graph stored in input 1 of that CNode.
    auto get_switch_kernel_graph = [&](size_t input_index) -> KernelGraphPtr {
      auto partial = switch_node->input(input_index);
      MS_EXCEPTION_IF_NULL(partial);
      auto partial_cnode = partial->cast<CNodePtr>();
      MS_EXCEPTION_IF_NULL(partial_cnode);
      auto graph_node = partial_cnode->input(1);
      MS_EXCEPTION_IF_NULL(graph_node);
      auto graph_value_node = graph_node->cast<ValueNodePtr>();
      MS_EXCEPTION_IF_NULL(graph_value_node);
      auto graph_value = graph_value_node->value();
      MS_EXCEPTION_IF_NULL(graph_value);
      auto child_graph = graph_value->cast<KernelGraphPtr>();
      return child_graph;
    };
    return {get_switch_kernel_graph(2), get_switch_kernel_graph(3)};
  }
  return {};
}
// Reports whether a "call" node dispatches through a Switch.
//
// Returns false when input 1 of the call is a ValueNode and true when it is a
// Switch CNode. An exception is logged when `call_node` is not a call
// primitive or when input 1 is neither of those two forms.
bool AnfRuntimeAlgorithm::IsSwitchCall(const CNodePtr &call_node) {
  MS_EXCEPTION_IF_NULL(call_node);
  if (!CheckPrimitiveType(call_node, prim::kPrimCall)) {
    MS_LOG(EXCEPTION) << "call node should be a 'call', but is a " << call_node->DebugString();
  }
  auto input1 = call_node->input(1);
  // Same guard GetCallNodeKernelGraph applies before inspecting input 1.
  MS_EXCEPTION_IF_NULL(input1);
  if (input1->isa<ValueNode>()) {
    return false;
  } else if (input1->isa<CNode>() && AnfAlgo::CheckPrimitiveType(input1, prim::kPrimSwitch)) {
    return true;
  }
  MS_LOG(EXCEPTION) << "Unexpected input1 of call node,input1:" << input1->DebugString();
}
bool AnfRuntimeAlgorithm::IsWhileTrueGraph(const KernelGraphPtr &child_graph) {
auto call_nodes = child_graph->FindNodeByPrimitive(prim::kPrimCall);
for (const auto &call_node : call_nodes) {
auto graphs = GetCallNodeKernelGraph(call_node);
if (graphs.size() == 1 && graphs[0] == child_graph->parent_graph()) {
return true;
}
}
return false;
}
} // namespace session
} // namespace mindspore