You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

186 lines
10 KiB

/**
 * Copyright 2019 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include <iostream>
#include <string>
#include <vector>
#include <set>
#include <tuple>
#include <utility>
#include <memory>
#include "ir/anf.h"
#include "ir/dtype.h"
#include "ir/base.h"
#include "ir/primitive.h"
#include "device/device_address.h"
#include "kernel/kernel.h"
#include "kernel/kernel_build_info.h"
#include "operator/ops.h"
#include "utils/contract.h"
namespace mindspore {
namespace session {
// Callback signature for visiting an anf node: receives the node and an int index and returns an Any.
// NOTE(review): the alias name is spelled "Funcion" (sic); it is kept as-is because renaming would break users.
using AnfVisitFuncion = std::function<Any(const AnfNodePtr &node, int index)>;
// An anf node paired with a size_t index.
// NOTE(review): judging by the helpers that return it, the index is an output index of the node - confirm.
using KernelWithIndex = std::pair<AnfNodePtr, size_t>;
class AnfRuntimeAlgorithm {
// get input_anf_node's real kernel by recurse
static KernelWithIndex VisitKernel(const AnfNodePtr &input_anf_node, size_t output_index);
static KernelWithIndex VisitKernelWithReturnType(const AnfNodePtr &input_anf_node, size_t output_index,
bool visit_nop_node = false,
const std::vector<PrimitivePtr> &return_types = {
static std::vector<AnfNodePtr> GetAllOutput(const AnfNodePtr &node,
const std::vector<PrimitivePtr> &return_types = {});
// get cnode primitive
static AnfNodePtr GetCNodePrimitiveNode(const CNodePtr &node);
static void SetNodeInput(const CNodePtr &node, const AnfNodePtr &input_node, size_t index);
static PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node);
// check whether anf node is a node of 'primitive_type',such as make_tuple is a cnode of kPrimMakeTuple
static bool CheckPrimitiveType(const AnfNodePtr &node, const PrimitivePtr &primitive_type);
// get kernel_name of anf node
static std::string GetCNodeName(const AnfNodePtr &node);
// get detail info of anf node
static std::string GetNodeDebugString(const AnfNodePtr &node);
// get attr of anf node
template <typename T>
static T GetNodeAttr(const AnfNodePtr &node, const std::string &key) {
if (!node->isa<CNode>()) {
std::string node_debug_log = node->DebugString();
MS_LOG(EXCEPTION) << "Only cnode has attr, but this anf is " << node_debug_log.c_str();
auto primitive = GetCNodePrimitive(node);
return GetValue<T>(primitive->GetAttr(key));
static bool IsTupleOutput(const AnfNodePtr &anf);
// set attr of anf node
static void SetNodeAttr(const std::string &key, const ValuePtr &value, const AnfNodePtr &node);
// set attr of key from 'from' node to 'to' node
static void CopyNodeAttr(const std::string &key, const AnfNodePtr &from, const AnfNodePtr &to);
// set a new key for attr from 'from' node to 'to' node
static void CopyNodeAttr(const std::string &old_key, const std::string &new_key, const AnfNodePtr &from,
const AnfNodePtr &to);
// set all attrs from 'from' node to 'to' node
static void CopyNodeAttrs(const AnfNodePtr &from, const AnfNodePtr &to);
// check whether a cnode has the specified attr.
static bool HasNodeAttr(const std::string &key, const AnfNodePtr &node);
// delete attr of anf node
static void EraseNodeAttr(const std::string &key, AnfNodePtr node);
// get the num of input real_kernel(which can be build and run in device)
static size_t GetInputTensorNum(const AnfNodePtr &node);
// get the num of output real_kernel(which can be build and run in device)
static size_t GetOutputTensorNum(const AnfNodePtr &node);
// get output format select of anf node
static std::string GetOutputFormat(const AnfNodePtr &node, size_t output_idx);
// get input format select of anf node
static std::string GetInputFormat(const AnfNodePtr &node, size_t input_idx);
// get prev node output width output index
static KernelWithIndex GetPrevNodeOutput(const AnfNodePtr &anf_node, size_t input_idx);
// get output format from prev node,input_index is the input index of current node related to prev node
static std::string GetPrevNodeOutputFormat(const AnfNodePtr &node, size_t input_idx);
// get output shapes inferred by ME from input nodes.
static std::vector<size_t> GetOutputInferShape(const AnfNodePtr &node, size_t output_idx);
// get input shapes inferred by ME from input nodes.
static std::vector<size_t> GetPrevNodeOutputInferShape(const AnfNodePtr &node, size_t input_idx);
// get output shapes which will built and run in device
static std::vector<size_t> GetOutputDeviceShape(const AnfNodePtr &node, size_t output_idx);
// get input shapes which will built and run in device
static std::vector<size_t> GetInputDeviceShape(const AnfNodePtr &node, size_t input_idx);
// Get Input Padding Axis
static std::vector<kernel::Axis> GetInputReshapeType(const AnfNodePtr &node, size_t output_idx);
// Get Output Padding Axis
static std::vector<kernel::Axis> GetOutputReshapeType(const AnfNodePtr &node, size_t output_idx);
// get output data type inferred by ME of anf node
static TypeId GetOutputInferDataType(const AnfNodePtr &node, size_t output_idx);
// get output original data type from prev node,input_index is the input index of current node related to prev node
static TypeId GetPrevNodeOutputInferDataType(const AnfNodePtr &node, size_t input_idx);
// get output select data type of anf node
static TypeId GetOutputDeviceDataType(const AnfNodePtr &node, size_t output_idx);
// get input select data type of anf node
static TypeId GetInputDeviceDataType(const AnfNodePtr &node, size_t input_idx);
// get output select data type from prev node,input_index is the input index of current node related to prev node
static TypeId GetPrevNodeOutputDeviceDataType(const AnfNodePtr &node, size_t input_idx);
// get output device addr of anf_node
static const DeviceAddress *GetOutputAddr(const AnfNodePtr &node, size_t output_idx);
// get mutable output device addr of anf_node
static DeviceAddressPtr GetMutableOutputAddr(const AnfNodePtr &node, size_t output_idx);
// check whether output addr is exist or not
static bool OutputAddrExist(const AnfNodePtr &node, size_t output_idx);
// get address from prev node,input_index is the input index of current node related to prev node
static const DeviceAddress *GetPrevNodeOutputAddr(const AnfNodePtr &node, size_t input_idx);
static DeviceAddressPtr GetPrevNodeMutableOutputAddr(const AnfNodePtr &anf_node, size_t input_idx);
// set output device addr of anf_node
static void SetOutputAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node);
// set workspace device addr of anf_node
static void SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t output_idx, AnfNode *node);
// get workspace device addr of anf_node
static DeviceAddress *GetWorkspaceAddr(const AnfNodePtr &node, size_t output_idx);
// set infer shapes and types of anf node
static void SetOutputInferTypeAndShape(const std::vector<TypeId> &types,
const std::vector<std::vector<size_t>> &shapes, AnfNode *node);
static void CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_node);
// get KernelBuildType of node ,such as ATT,RT,FWK and so on
static KernelType GetKernelType(const AnfNodePtr &node);
// get processor type:AICORE,AICPU...
static kernel::Processor GetProcessor(const AnfNodePtr &node);
// get fusion type:AICORE,AICPU...
static kernel::FusionType GetFusionType(const AnfNodePtr &node);
// set select kernel_build_info
static void SetSelectKernelBuildInfo(const kernel::KernelBuildInfoPtr &select_kernel_build_info, AnfNode *node);
// get select kernel_build_info
static kernel::KernelBuildInfoPtr GetSelectKernelBuildInfo(const AnfNodePtr &node);
// get kernelMode
static kernel::KernelMod *GetKernelMod(const AnfNodePtr &node);
// set kernel mod
static void SetKernelMod(const kernel::KernelModPtr &kernel_mod, AnfNode *node);
// checkout whether the anf node is a real kernel that can run on device,parameter and constant is real kernel too
static bool IsRealKernel(const AnfNodePtr &node);
// checkout whether the anf node is a real kernel that is a cnode and can run on device
static bool IsRealCNodeKernel(const AnfNodePtr &node);
// check parameter is weight or data
static bool IsParameterWeight(const ParameterPtr &node);
// set stream id of kernel,which will be set in stream assign and be used in stream generate
static void SetStreamId(uint32_t stream_id, AnfNode *node);
// get stream id
static uint32_t GetStreamId(const AnfNodePtr &node);
// set stream distinction label to distinguish different ops in different streams
static void SetStreamDistinctionLabel(uint32_t stream_label, AnfNode *node);
// get stream distinction label
static uint32_t GetStreamDistinctionLabel(const AnfNode *node);
// set graph id
static void SetGraphId(uint32_t graph_id, AnfNode *node);
// get graph id
static uint32_t GetGraphId(const AnfNode *node);
static AnfNodePtr GetInputNode(const CNodePtr &node, size_t index);
// charge if the node's output is a feature map output
static bool IsFeatureMapOutput(const AnfNodePtr &node);
// charge if the node's input is from a feature map output
static bool IsFeatureMapInput(const AnfNodePtr &node, size_t input_index);
// get real input index for some tbe ops which input order is different between me and tbe impl
static size_t GetRealInputIndex(const AnfNodePtr &anf_node, const size_t cur_index);
static bool IsCommunicationOp(const AnfNodePtr &node);
static bool IsAllReduceOp(const AnfNodePtr &node);
static bool IsGetNext(const NotNull<AnfNodePtr> &node);
} // namespace session
// Short alias exposing session::AnfRuntimeAlgorithm directly in the mindspore namespace.
using AnfAlgo = session::AnfRuntimeAlgorithm;
} // namespace mindspore