You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
513 lines
18 KiB
513 lines
18 KiB
/**
|
|
* This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
|
|
*
|
|
* Copyright 2019 Huawei Technologies Co., Ltd
|
|
*
|
|
* Licensed under the Apache License, Version 2.0 (the "License");
|
|
* you may not use this file except in compliance with the License.
|
|
* You may obtain a copy of the License at
|
|
*
|
|
* http://www.apache.org/licenses/LICENSE-2.0
|
|
*
|
|
* Unless required by applicable law or agreed to in writing, software
|
|
* distributed under the License is distributed on an "AS IS" BASIS,
|
|
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
* See the License for the specific language governing permissions and
|
|
* limitations under the License.
|
|
*/
|
|
|
|
#ifndef MINDSPORE_CORE_IR_ANF_H_
|
|
#define MINDSPORE_CORE_IR_ANF_H_
|
|
|
|
#include <functional>
|
|
#include <string>
|
|
#include <vector>
|
|
#include <memory>
|
|
#include <unordered_map>
|
|
#include <utility>
|
|
#include <set>
|
|
|
|
#include "base/base.h"
|
|
#include "base/user_data.h"
|
|
#include "ir/kernel_info_dev.h"
|
|
#include "ir/scope.h"
|
|
#include "utils/info.h"
|
|
|
|
// A MindSpore ANF IR defined here.
|
|
// with the following BNF:
|
|
// <ANode> ::= Scalar | Named | Tensor | Var |
|
|
// Prim | MetaFuncGraph | FuncGraph | Type|
|
|
// Shape | Param
|
|
// <CNode> ::= (<ANode> ...)
|
|
// <AnfNode> ::= <CNode> | <ANode>
|
|
// ANode: Atomic Node
|
|
// CNode: Complex Node
|
|
namespace mindspore {
|
|
namespace abstract {
|
|
class BaseShape;
|
|
class AbstractBase;
|
|
} // namespace abstract
|
|
using BaseShapePtr = std::shared_ptr<abstract::BaseShape>;
|
|
using AbstractBasePtr = std::shared_ptr<abstract::AbstractBase>;
|
|
using AbstractBasePtrList = std::vector<AbstractBasePtr>;
|
|
|
|
class Value;
|
|
using ValuePtr = std::shared_ptr<Value>;
|
|
using ValuePtrList = std::vector<ValuePtr>;
|
|
|
|
class ValueNode;
|
|
using ValueNodePtr = std::shared_ptr<ValueNode>;
|
|
|
|
class CNode;
|
|
using CNodePtr = std::shared_ptr<CNode>;
|
|
|
|
class FuncGraph;
|
|
using FuncGraphSet = OrderedSet<FuncGraphPtr>;
|
|
using FuncGraphPtrList = std::vector<FuncGraphPtr>;
|
|
|
|
class Primitive;
|
|
using PrimitivePtr = std::shared_ptr<Primitive>;
|
|
|
|
class BaseRef;
|
|
|
|
class Var;
|
|
using VarPtr = std::shared_ptr<Var>;
|
|
|
|
class AnfIrVisitor;
|
|
|
|
class ParamInfo;
|
|
using ParamInfoPtr = std::shared_ptr<ParamInfo>;
|
|
|
|
// AnfNode is the basic class of the IR definition derived from Base.
|
|
// Only two types of nodes are derived: CNode and ANode.
|
|
// Methods:
|
|
// func_graph: return FuncGraph that this AnfNode belongs to.
|
|
// scope: return the scope namespace of this AnfNode. Set it using set_scope.
|
|
// abstract: return the cached inferred abstract value. It contains type, shape
|
|
// and value. Set a new cache using set_abstract.
|
|
// intermediate_abstract: return the cached inferring abstract value.
|
|
// Type/Shape: return the related info of this AnfNode. When this AnfNode is an
|
|
// input of other CNodes, you can get the related info by this method.
|
|
// debug_info: return the information retrieved from parser. Set it using set_debug_info.
|
|
// fullname_with_scope: return the detailed debug info.
|
|
class AnfNode : public Base {
|
|
public:
|
|
explicit AnfNode(const FuncGraphPtr &func_graph)
|
|
: func_graph_(FuncGraphWeakPtr(func_graph)),
|
|
abstract_(nullptr),
|
|
intermediate_abstract_(nullptr),
|
|
debug_info_(std::make_shared<NodeDebugInfo>()),
|
|
fullname_with_scope_(""),
|
|
hash_(std::hash<const AnfNode *>()),
|
|
kernel_info_(nullptr),
|
|
stage_(-1) {
|
|
scope_ = ScopeManager::GetInstance().GetCurrentScope();
|
|
}
|
|
|
|
~AnfNode() override = default;
|
|
MS_DECLARE_PARENT(AnfNode, Base);
|
|
|
|
virtual void accept(AnfIrVisitor *) {}
|
|
FuncGraphPtr func_graph() const { return func_graph_.lock(); }
|
|
|
|
void set_func_graph(const FuncGraphPtr &func_graph) { func_graph_ = FuncGraphWeakPtr(func_graph); }
|
|
|
|
ScopePtr scope() { return scope_; }
|
|
void set_scope(const ScopePtr &scope) { scope_ = scope; }
|
|
|
|
const KernelInfoDevice *kernel_info() const { return kernel_info_.get(); }
|
|
KernelInfoDevice *kernel_info() { return kernel_info_.get(); }
|
|
const KernelInfoDevicePtr &kernel_info_ptr() { return kernel_info_; }
|
|
void set_kernel_info(const KernelInfoDevicePtr &kernel_info) { kernel_info_ = kernel_info; }
|
|
|
|
AbstractBasePtr abstract() const { return abstract_; }
|
|
void set_abstract(const AbstractBasePtr &abs) { abstract_ = abs; }
|
|
|
|
AbstractBasePtr intermediate_abstract() { return intermediate_abstract_; }
|
|
void set_intermediate_abstract(const AbstractBasePtr &abs) { intermediate_abstract_ = abs; }
|
|
|
|
NodeDebugInfoPtr debug_info() {
|
|
MS_EXCEPTION_IF_NULL(debug_info_);
|
|
if (debug_info_->get_node() == nullptr) {
|
|
debug_info_->set_node(shared_from_base<AnfNode>());
|
|
}
|
|
return debug_info_;
|
|
}
|
|
void set_debug_info(const NodeDebugInfoPtr &debug_info) {
|
|
debug_info_ = debug_info;
|
|
if (debug_info_->get_node() == nullptr) {
|
|
debug_info_->set_node(shared_from_base<AnfNode>());
|
|
}
|
|
}
|
|
|
|
TypePtr Type() const;
|
|
BaseShapePtr Shape() const;
|
|
|
|
std::size_t hash() const override { return this->hash_(this); }
|
|
virtual std::string fullname_with_scope() { return ""; }
|
|
|
|
virtual std::string DebugString(int recursive_level = 1) const { return ToString(); }
|
|
virtual std::string DebugString(bool recursive) const { return DebugString(recursive ? 1 : 0); }
|
|
std::string ToString() const override;
|
|
void dump() const override { std::cout << DebugString() << std::endl; }
|
|
std::string UniqueId() { return std::to_string(debug_info()->unique_id()); }
|
|
std::string UniqueIdThroughCopy() { return std::to_string(debug_info()->unique_id_through_copy()); }
|
|
virtual bool operator==(const AnfNode &other) const { return &other == this; }
|
|
friend std::ostream &operator<<(std::ostream &os, const AnfNode &node) {
|
|
os << node.ToString();
|
|
return os;
|
|
}
|
|
size_t seen_{0};
|
|
size_t extra_seen_{0};
|
|
|
|
template <typename T>
|
|
void set_user_data(const std::string &key, const std::shared_ptr<T> &value) {
|
|
user_data_.set<T>(key, value);
|
|
}
|
|
|
|
template <typename T>
|
|
void set_user_data(const std::shared_ptr<T> &value) {
|
|
user_data_.set<T>(T::key, value);
|
|
}
|
|
|
|
template <typename T>
|
|
std::shared_ptr<T> user_data(const std::string &key) const {
|
|
return user_data_.get<T>(key);
|
|
}
|
|
|
|
template <typename T>
|
|
std::shared_ptr<T> user_data() const {
|
|
return user_data_.get<T>(T::key);
|
|
}
|
|
|
|
bool has_user_data(const std::string &key) const { return user_data_.has(key); }
|
|
|
|
template <typename T>
|
|
bool has_user_data() const {
|
|
return user_data_.has(T::key);
|
|
}
|
|
|
|
int64_t stage() { return stage_; }
|
|
void set_stage(const int &stage) { stage_ = stage; }
|
|
|
|
protected:
|
|
// Hold a weak ref to Graph as Graph also hold ref to AnfNode.
|
|
// Otherwise, func_graph_ and AnfNode will make a reference cycle.
|
|
FuncGraphWeakPtr func_graph_;
|
|
AbstractBasePtr abstract_;
|
|
AbstractBasePtr intermediate_abstract_;
|
|
NodeDebugInfoPtr debug_info_;
|
|
std::string fullname_with_scope_;
|
|
|
|
private:
|
|
std::hash<const AnfNode *> hash_;
|
|
ScopePtr scope_;
|
|
KernelInfoDevicePtr kernel_info_;
|
|
UserData user_data_;
|
|
int64_t stage_;
|
|
};
|
|
|
|
// CNode represents the complex node with a set of arguments.
|
|
// Fields:
|
|
// inputs_: represents all of the inputs for this CNode.
|
|
// Using input(i) to get the index i input.
|
|
// Using inputs() to get all the inputs as a vector.
|
|
// Using add_input(input) to append a new input for a CNode.
|
|
// Using set_input(i, input) to change some input of these inputs.
|
|
// Using set_inputs(inputs) to refresh all of the inputs of a CNode.
|
|
// func_graph_as_var_: used in opt pattern matching to match a real FuncGraph.
|
|
// stop_gradient_: a flag used to stop gradient.
|
|
// Using stop_gradient() to get this flag, mainly used in ad.
|
|
// Using set_stop_gradient() to set this flag.
|
|
class CNode : public AnfNode {
|
|
public:
|
|
CNode(const std::vector<AnfNodePtr> &inputs, const FuncGraphPtr &func_graph);
|
|
CNode(const std::vector<AnfNodePtr> &inputs, const VarPtr &func_graph_as_var)
|
|
: AnfNode(nullptr), inputs_(inputs), func_graph_as_var_(func_graph_as_var), stop_gradient_(false) {}
|
|
|
|
~CNode() override = default;
|
|
MS_DECLARE_PARENT(CNode, AnfNode);
|
|
|
|
void accept(AnfIrVisitor *v) override;
|
|
// check whether this cnode has some primitive value as the first input.
|
|
bool IsApply(const PrimitivePtr &) const;
|
|
|
|
const size_t size() const { return inputs_.size(); }
|
|
const AnfNodePtr input(size_t i) const { return inputs_[i]; }
|
|
const std::vector<AnfNodePtr> &inputs() const { return inputs_; }
|
|
void add_input(const AnfNodePtr &input) { inputs_.push_back(input); }
|
|
void set_input(size_t i, const AnfNodePtr &input);
|
|
void set_inputs(const std::vector<AnfNodePtr> &inputs) { inputs_ = inputs; }
|
|
|
|
void add_input_value(const ValuePtr &input_value, const std::string &id) {
|
|
inputs_value_.push_back(std::make_pair(input_value, id));
|
|
}
|
|
void clear_inputs_value() { inputs_value_.clear(); }
|
|
void set_inputs_value(const std::vector<std::pair<ValuePtr, std::string>> &values) { inputs_value_ = values; }
|
|
const std::vector<std::pair<ValuePtr, std::string>> &inputs_value() const { return inputs_value_; }
|
|
|
|
void set_forward(const ValuePtr &forward, const std::string &id) { output_value_ = std::make_pair(forward, id); }
|
|
const std::pair<ValuePtr, std::string> &forward() const { return output_value_; }
|
|
|
|
bool stop_gradient() const { return stop_gradient_; }
|
|
void set_stop_gradient(bool stop_gradient) { stop_gradient_ = stop_gradient; }
|
|
|
|
std::string fullname_with_scope() override;
|
|
void set_fullname_with_scope(const std::string full_name) { fullname_with_scope_ = full_name; }
|
|
std::string DebugString(int recursive_level = 1) const override;
|
|
std::string DebugString(bool recursive) const override { return DebugString(recursive ? 1 : 0); }
|
|
|
|
void set_in_forward_flag(bool flag) { in_forward_flag_ = flag; }
|
|
bool in_forward_flag() const { return in_forward_flag_; }
|
|
|
|
VarPtr func_graph_as_var() const { return func_graph_as_var_; }
|
|
|
|
private:
|
|
std::vector<AnfNodePtr> inputs_;
|
|
VarPtr func_graph_as_var_;
|
|
bool stop_gradient_;
|
|
bool in_forward_flag_ = false;
|
|
// inputs_value_ store cnode input value and id in pynative mode
|
|
// output_value_ store cnode value and id in pynative mode
|
|
std::vector<std::pair<ValuePtr, std::string>> inputs_value_;
|
|
std::pair<ValuePtr, std::string> output_value_;
|
|
};
|
|
|
|
// ANode represents the atomic node. Parameter and ValueNode are derived from it.
|
|
class ANode : public AnfNode {
|
|
public:
|
|
ANode() : AnfNode(nullptr) {}
|
|
explicit ANode(const FuncGraphPtr &func_graph) : AnfNode(func_graph) {}
|
|
virtual ~ANode() = default;
|
|
|
|
MS_DECLARE_PARENT(ANode, AnfNode);
|
|
};
|
|
|
|
// Parameter represents the parameter inputs of a function. They have no value.
|
|
// Attributes:
|
|
// default_param_value_: used to hold the inputting tensor of the model.
|
|
class Parameter : public ANode {
|
|
public:
|
|
explicit Parameter(const FuncGraphPtr &func_graph)
|
|
: ANode(func_graph), name_(""), has_default_(false), default_param_(nullptr), used_graph_count_(0) {}
|
|
~Parameter() override = default;
|
|
MS_DECLARE_PARENT(Parameter, ANode);
|
|
|
|
void accept(AnfIrVisitor *v) override;
|
|
std::string DebugString(int recursive_level = 1) const override;
|
|
std::string name() const { return name_; }
|
|
void set_name(const std::string &name) { name_ = name; }
|
|
std::string fullname_with_scope() override { return name(); }
|
|
|
|
bool has_default() const { return has_default_; }
|
|
void set_default_param(ValuePtr param) {
|
|
default_param_ = param;
|
|
has_default_ = true;
|
|
}
|
|
ValuePtr default_param() const { return default_param_; }
|
|
ParamInfoPtr param_info() const;
|
|
|
|
void IncreaseUsedGraphCount() { used_graph_count_++; }
|
|
void DecreaseUsedGraphCount() { used_graph_count_--; }
|
|
int used_graph_count() const { return used_graph_count_; }
|
|
|
|
bool operator==(const AnfNode &other) const override {
|
|
if (!other.isa<Parameter>()) {
|
|
return false;
|
|
}
|
|
auto p = static_cast<const Parameter &>(other);
|
|
if (name_.length() > 0 && p.name_.length() > 0) {
|
|
return p.name_ == name_;
|
|
}
|
|
return shared_from_this() == other.shared_from_this();
|
|
}
|
|
|
|
void set_used_by_real_kernel() { is_real_kernel_used_ = false; }
|
|
bool is_used_by_real_kernel() { return is_real_kernel_used_; }
|
|
|
|
void set_used_by_dynamic_kernel() { is_used_by_dynamic_kernel_ = true; }
|
|
bool is_used_by_dynamic_kernel() { return is_used_by_dynamic_kernel_; }
|
|
|
|
private:
|
|
std::string name_;
|
|
bool has_default_;
|
|
bool is_real_kernel_used_ = true;
|
|
bool is_used_by_dynamic_kernel_ = false;
|
|
ValuePtr default_param_;
|
|
// The count of graphs using the parameter.
|
|
int used_graph_count_;
|
|
};
|
|
using ParameterPtr = std::shared_ptr<Parameter>;
|
|
|
|
// Value is used to represent the atomic expression mentioned in BNF.
|
|
// It is mainly stored in ValueNode. Value and ValueNode are related definitions.
|
|
class Value : public Base {
|
|
public:
|
|
Value() = default;
|
|
explicit Value(const TypePtr t) : type_(t) {}
|
|
Value(const Value &other) : Base(other) { this->type_ = other.type_; }
|
|
~Value() override = default;
|
|
MS_DECLARE_PARENT(Value, Base)
|
|
|
|
TypePtr type() const { return type_; }
|
|
virtual abstract::AbstractBasePtr ToAbstract() { MS_LOG(EXCEPTION) << "ToAbstract error"; }
|
|
|
|
virtual bool operator==(const Value &rhs) const = 0;
|
|
virtual Value &operator=(const Value &other) {
|
|
if (&other == this) {
|
|
return *this;
|
|
}
|
|
this->type_ = other.type_;
|
|
return *this;
|
|
}
|
|
|
|
protected:
|
|
TypePtr type_{nullptr};
|
|
};
|
|
|
|
// ValueNode is used to hold value. Unlike CNode and Parameter, ValueNode
|
|
// does not belong to any particular function graph.
|
|
class ValueNode : public ANode {
|
|
public:
|
|
explicit ValueNode(const ValuePtr &value) : value_(value) {}
|
|
~ValueNode() override = default;
|
|
MS_DECLARE_PARENT(ValueNode, ANode);
|
|
|
|
void accept(AnfIrVisitor *v) override;
|
|
void set_value(const ValuePtr &value) { value_ = value; }
|
|
const ValuePtr &value() const { return value_; }
|
|
std::string fullname_with_scope() override;
|
|
|
|
void set_has_new_value(bool flag) { has_new_value_ = flag; }
|
|
bool has_new_value() const { return has_new_value_; }
|
|
|
|
std::string ToString() const override;
|
|
std::string DebugString(int recursive_level = 1) const override;
|
|
std::string DebugString(bool recursive) const override { return DebugString(recursive ? 1 : 0); }
|
|
|
|
bool operator==(const AnfNode &other) const override {
|
|
if (!other.isa<ValueNode>()) {
|
|
return false;
|
|
}
|
|
auto v = static_cast<const ValueNode &>(other);
|
|
return *v.value() == *value();
|
|
}
|
|
friend std::ostream &operator<<(std::ostream &os, const ValueNodePtr &node) {
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
os << node->ToString();
|
|
return os;
|
|
}
|
|
|
|
private:
|
|
ValuePtr value_;
|
|
bool has_new_value_ = false;
|
|
};
|
|
|
|
// Primary trait template: intentionally empty. Specializations (see IMM_TRAITS) add a
// nested `type` alias for each supported C++ type.
template <class T>
struct ImmTraits {
};
|
|
|
|
// Specializes ImmTraits for `prototype` so that ImmTraits<prototype>::type is `typeimm`.
// The expansion already ends with "};".
#define IMM_TRAITS(typeimm, prototype)                                                             \
  template <>                                                                                      \
  struct ImmTraits<prototype> {                                                                    \
    using type = typeimm;                                                                          \
  };
|
|
|
|
// Identity overload: something that already is a ValuePtr is returned unchanged.
inline ValuePtr MakeValue(const ValuePtr &value) {
  return value;
}
|
|
|
|
template <typename S, typename U = typename ImmTraits<S>::type::element_type>
|
|
inline ValuePtr MakeValue(S v) {
|
|
return std::make_shared<U>(v);
|
|
}
|
|
|
|
template <typename S, typename U = typename ImmTraits<S>::type>
|
|
static S GetValue(const ValuePtr &value) {
|
|
MS_EXCEPTION_IF_NULL(value);
|
|
|
|
U imm = value->cast<U>();
|
|
if (imm == nullptr) {
|
|
MS_LOG(EXCEPTION) << "Cast failed, original value: " << value->ToString() << ", type: " << value->type_name();
|
|
}
|
|
return imm->value();
|
|
}
|
|
|
|
template <typename S,
|
|
typename std::enable_if<is_shared_ptr<S>::value && std::is_base_of<Value, typename S::element_type>::value,
|
|
S>::type * = nullptr>
|
|
static S GetValue(const ValuePtr &value) {
|
|
MS_EXCEPTION_IF_NULL(value);
|
|
S v = value->cast<S>();
|
|
if (v == nullptr) {
|
|
MS_LOG(EXCEPTION) << "Cast failed, original value: " << value->ToString() << ", type: " << value->type_name();
|
|
}
|
|
return v;
|
|
}
|
|
|
|
std::string GetCNodeFuncName(CNodePtr cnode);
|
|
|
|
// used to check whether an AnfNode is a cnode with a kind of Primitive as first input
|
|
bool IsPrimitiveCNode(const AnfNodePtr &node, const PrimitivePtr &value = nullptr);
|
|
|
|
// used to get PrimitivePtr from a cnode first input
|
|
PrimitivePtr GetCNodePrimitive(const AnfNodePtr &node);
|
|
|
|
// used to check whether an AnfNode is a valuenode having some Primitive value
|
|
bool IsPrimitive(const AnfNodePtr &node, const PrimitivePtr &value);
|
|
|
|
// used to check whether a ValueNode has some kind of value
|
|
template <typename T>
|
|
static bool IsValueNode(const AnfNodePtr &node) {
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
auto anode = node->cast<ValueNodePtr>();
|
|
if (anode != nullptr) {
|
|
auto value = anode->value();
|
|
if (value == nullptr) {
|
|
MS_LOG(EXCEPTION) << "Const value is nullptr.";
|
|
}
|
|
return value->isa<T>();
|
|
}
|
|
return false;
|
|
}
|
|
|
|
// Returns the value held by `node` when it is a ValueNode, otherwise nullptr.
// Raises through MS_EXCEPTION_IF_NULL when `node` itself is null.
inline ValuePtr GetValueNode(const AnfNodePtr &node) {
  MS_EXCEPTION_IF_NULL(node);
  if (node->isa<ValueNode>()) {
    auto value_node = node->cast<ValueNodePtr>();
    return value_node->value();
  }
  return nullptr;
}
|
|
|
|
template <typename S,
|
|
typename std::enable_if<is_shared_ptr<S>::value && std::is_base_of<Value, typename S::element_type>::value,
|
|
S>::type * = nullptr>
|
|
inline S GetValueNode(const AnfNodePtr &node) {
|
|
auto value = GetValueNode(node);
|
|
if (value == nullptr) {
|
|
return nullptr;
|
|
}
|
|
auto s = value->cast<S>();
|
|
return s;
|
|
}
|
|
|
|
size_t NewSeenGeneration();
|
|
|
|
namespace id_generator {
|
|
std::string get_id(const AnfNodePtr &node);
|
|
void reset_id();
|
|
} // namespace id_generator
|
|
using TaggedNodeMap = std::unordered_map<AnfNodePtr, size_t>;
|
|
using TaggedGraph = std::pair<FuncGraphPtr, TaggedNodeMap>;
|
|
std::string GetCNodeTarget(const AnfNodePtr &node);
|
|
bool ContainMultiTarget(const std::vector<AnfNodePtr> &nodes);
|
|
struct GraphSegment {
|
|
GraphSegment(const std::vector<AnfNodePtr> &nodes, bool is_cut) : nodes_(nodes), is_cut_(is_cut) {}
|
|
void AddPreSegment(const std::shared_ptr<GraphSegment> &segment) { (void)pre_segments_.insert(segment); }
|
|
std::vector<AnfNodePtr> nodes_;
|
|
std::set<std::shared_ptr<GraphSegment>> pre_segments_;
|
|
bool is_cut_{false};
|
|
uint32_t graph_id_{0};
|
|
};
|
|
using GraphSegmentPtr = std::shared_ptr<GraphSegment>;
|
|
} // namespace mindspore
|
|
|
|
#endif // MINDSPORE_CORE_IR_ANF_H_
|