!3129 Decouple ir from frontend

Merge pull request !3129 from hewei/decouple_ir_frontend
pull/3129/MERGE
mindspore-ci-bot authored 5 years ago, committed by Gitee
commit a2bf5a322e

@ -27,6 +27,7 @@
#include "runtime/device/kernel_info.h"
#include "utils/graph_utils.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "frontend/parallel/ops_info/operator_info.h"
namespace mindspore {
const std::string ToShortString(const TypeId &typeId) {
@ -266,7 +267,7 @@ void DumpParallelInfo(const CNodePtr &node, const std::shared_ptr<SubGraphIRInfo
return;
}
auto operator_info = node->operator_info();
auto operator_info = node->GetUserData<parallel::OperatorInfo>();
if (operator_info == nullptr) {
return;
}

@ -437,7 +437,7 @@ static void DrawParallelInfo(Graphviz *const graph_obj, const CNodePtr &node) {
if (graph_obj == nullptr || node == nullptr) {
return;
}
auto distributed_operation_info = node->operator_info();
auto distributed_operation_info = node->GetUserData<parallel::OperatorInfo>();
if (distributed_operation_info != nullptr) {
auto strategyPtr = distributed_operation_info->strategy();
if (strategyPtr != nullptr) {

File diff suppressed because it is too large

File diff suppressed because it is too large

@ -50,7 +50,7 @@ std::unordered_set<CNodePtr> FindCNodesWithPara(const AnfNodePtr &para, uint32_t
if (node_prim->name() == DEPEND && node_pair.second != 1) {
continue;
}
if (IsParallelCareNode(cnode) && cnode->operator_info() != nullptr) {
if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) {
(void)cnode_set.emplace(cnode);
} else {
auto cnode_set_sub = FindCNodesWithPara(node_pair.first, recursive_times + 1);
@ -98,11 +98,12 @@ CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr &from, uint32_t recursi
return cnode_dist;
}
auto operator_info = cnode->GetUserData<OperatorInfo>();
MS_LOG(DEBUG) << "cnode " << cnode->ToString() << " IsParallelCareNode: " << IsParallelCareNode(cnode)
<< " operator_info: " << (cnode->operator_info() != nullptr);
<< " operator_info: " << (operator_info != nullptr);
if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) {
auto cost = cnode->operator_info()->GetForwardMemoryCostFromCNode();
if (IsParallelCareNode(cnode) && (operator_info != nullptr)) {
auto cost = operator_info->GetForwardMemoryCostFromCNode();
MS_LOG(DEBUG) << "cnode " << cnode->DebugString() << " cost: " << cost;
if (allreduce_graph_.NodeInGraph(cnode)) {

@ -83,7 +83,7 @@ Status AllreduceNode::AddPara(const AnfNodePtr &node_ptr) {
}
auto para_ptr = node_ptr->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(para_ptr);
auto layout_ptr = para_ptr->tensor_layout();
auto layout_ptr = para_ptr->GetUserData<TensorLayout>();
if (layout_ptr == nullptr) {
MS_LOG(ERROR) << "layout_ptr is nullptr!";
return FAILED;

@ -37,7 +37,7 @@ py::dict GetParameterLayout(const FuncGraphPtr &graph) {
for (auto para : graph_params) {
std::string name = std::static_pointer_cast<Parameter>(para)->name();
std::shared_ptr<parallel::TensorLayout> tensor_layout = std::static_pointer_cast<Parameter>(para)->tensor_layout();
auto tensor_layout = para->GetUserData<parallel::TensorLayout>();
if (tensor_layout == nullptr) {
MS_LOG(INFO) << "GetParameterLayout nullptr name = " << name;
} else {
@ -70,7 +70,7 @@ py::dict GetCNodeStrategy(const FuncGraphPtr &graph) {
if (node->isa<CNode>()) {
auto cnode = node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
auto distributed_operation_info = cnode->operator_info();
auto distributed_operation_info = cnode->GetUserData<OperatorInfo>();
if (distributed_operation_info != nullptr) {
auto strategyPtr = distributed_operation_info->strategy();
if (strategyPtr != nullptr) {

@ -163,6 +163,9 @@ class OperatorInfo {
const std::string &type() const { return type_; }
const std::unordered_map<std::string, ValuePtr> &attrs() const { return attrs_; }
// Key for user data.
constexpr static char key[] = "OpInfo";
protected:
// needed by rec_parser
std::string type_;
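
For reference, this constexpr static key is what lets call sites elsewhere in the diff write cnode->GetUserData<OperatorInfo>() without repeating a string literal: the templated accessors added to AnfNode later in this change look the payload up under T::key. Below is a minimal standalone sketch of the pattern (C++17); SimpleUserData and FakeOpInfo are hypothetical stand-ins, not the real MindSpore classes.

// Sketch only: a per-type static key plus templated accessors recover a typed
// value from a string-keyed, type-erased store.
#include <iostream>
#include <map>
#include <memory>
#include <string>

class SimpleUserData {
 public:
  template <typename T>
  void Set(const std::shared_ptr<T> &value) {
    data_[T::key] = value;  // T::key supplies the string key for this payload type.
  }
  template <typename T>
  std::shared_ptr<T> Get() const {
    auto iter = data_.find(T::key);
    return iter == data_.end() ? nullptr : std::static_pointer_cast<T>(iter->second);
  }

 private:
  std::map<std::string, std::shared_ptr<void>> data_;  // type-erased storage
};

struct FakeOpInfo {
  constexpr static char key[] = "OpInfo";  // mirrors the key added above
  std::string name = "MatMulInfo";
};

int main() {
  SimpleUserData user_data;
  user_data.Set(std::make_shared<FakeOpInfo>());
  if (auto info = user_data.Get<FakeOpInfo>()) {
    std::cout << info->name << std::endl;  // prints MatMulInfo
  }
  return 0;
}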

@ -435,7 +435,7 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector<AnfNodePtr> &all_node
std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode);
entire_costgraph->AddOperator(operator_info);
(void)cnode->set_operator_info(operator_info);
cnode->SetUserData<OperatorInfo>(operator_info);
MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
<< " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
<< " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
@ -501,7 +501,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
std::vector<std::string> inputs_tensor_name = ExtractInputsTensorName(cnode);
entire_costgraph->AddOperator(operator_info);
(void)cnode->set_operator_info(operator_info);
cnode->SetUserData<OperatorInfo>(operator_info);
MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
<< " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
<< " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name();
@ -520,7 +520,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector<AnfNodePtr> &all_no
MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name()
<< " does not match the Prim: " << prim->name();
}
(void)cnode->set_operator_info(current_op_ptr);
cnode->SetUserData<OperatorInfo>(current_op_ptr);
MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId()
<< " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy()
<< " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name();
@ -549,6 +549,8 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
size_t edge_count = 0;
auto node_op_info = cnode->GetUserData<OperatorInfo>();
for (size_t i = 1; i < inputs.size(); ++i) {
auto prev_cnode = inputs[i]->cast<CNodePtr>();
bool bool_result_prev_cnode = (prev_cnode == nullptr) || (!IsValueNode<Primitive>(prev_cnode->input(0)));
@ -563,8 +565,8 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
(IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND);
while (bool_result) {
if (IsAutoParallelCareNode(prev_cnode)) {
std::string edge_name =
prev_cnode->operator_info()->name() + OPERATOR_TO_OPERATOR_CONNECTOR + cnode->operator_info()->name();
auto prev_op_info = prev_cnode->GetUserData<OperatorInfo>();
std::string edge_name = prev_op_info->name() + OPERATOR_TO_OPERATOR_CONNECTOR + node_op_info->name();
// If the edge between these two operators already has been added, then the edge will not be added again.
if (entire_costgraph->IsEdgeInCostGraph(edge_name, output_index, i - 1)) {
break;
@ -577,22 +579,20 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
if (follow_strategy) {
// Redistribution in not allowed on the edge.
// Elementwise operators have the same strategy as their previous operators.
edge_ptr = std::make_shared<Edge>(edge_name, prev_cnode->operator_info(), cnode->operator_info(),
output_index, i - 1, false, true);
edge_ptr = std::make_shared<Edge>(edge_name, prev_op_info, node_op_info, output_index, i - 1, false, true);
} else {
edge_ptr = std::make_shared<Edge>(edge_name, prev_cnode->operator_info(), cnode->operator_info(),
output_index, i - 1, false);
edge_ptr = std::make_shared<Edge>(edge_name, prev_op_info, node_op_info, output_index, i - 1, false);
}
// Init costs for this edge
if (edge_ptr->InitEdgeCost() != SUCCESS) {
MS_LOG(EXCEPTION) << "Edge cost initialization failed";
}
cnode->operator_info()->AddPrevEdge(edge_ptr);
prev_cnode->operator_info()->AddSuccEdge(edge_ptr);
entire_costgraph->AddEdge(prev_cnode->operator_info(), cnode->operator_info(), edge_ptr);
MS_LOG(INFO) << "Successfully adding the edge between " << prev_cnode->operator_info()->name() << " and "
<< cnode->operator_info()->name();
node_op_info->AddPrevEdge(edge_ptr);
prev_op_info->AddSuccEdge(edge_ptr);
entire_costgraph->AddEdge(prev_op_info, node_op_info, edge_ptr);
MS_LOG(INFO) << "Successfully adding the edge between " << prev_op_info->name() << " and "
<< node_op_info->name();
edge_count++;
break;
@ -633,7 +633,7 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
(IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND);
}
}
MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << cnode->operator_info()->name();
MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << node_op_info->name();
}
MS_LOG(INFO) << "Constructing edges for cost graph ends.";
@ -750,7 +750,8 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
for (auto &target : target_set) {
auto target_cnode = target.first->cast<CNodePtr>();
auto input_index = target.second;
(void)target_without_duplicate.insert(std::to_string(input_index) + target_cnode->operator_info()->name());
(void)target_without_duplicate.insert(std::to_string(input_index) +
target_cnode->GetUserData<OperatorInfo>()->name());
}
if (target_without_duplicate.size() <= 1) {
continue;
@ -830,24 +831,24 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
auto target_cnode = target.first->cast<CNodePtr>();
auto prim = GetValueNode<PrimitivePtr>(target_cnode->input(0));
auto input_index = target.second;
auto target_op_info = target_cnode->GetUserData<OperatorInfo>();
std::string edge_name =
std::string(IDENTITY_INFO) + OPERATOR_TO_OPERATOR_CONNECTOR + target_cnode->operator_info()->name();
std::string edge_name = std::string(IDENTITY_INFO) + OPERATOR_TO_OPERATOR_CONNECTOR + target_op_info->name();
// If the edge between these two operators already has been added, then the edge will not be added again.
if (entire_costgraph->IsEdgeInCostGraph(edge_name, 0, IntToSize(input_index - 1))) {
continue;
}
std::shared_ptr<Edge> edge_ptr = std::make_shared<Edge>(
edge_name, tmp_identity_ptr, target_cnode->operator_info(), 0, input_index - 1, false, true);
std::shared_ptr<Edge> edge_ptr =
std::make_shared<Edge>(edge_name, tmp_identity_ptr, target_op_info, 0, input_index - 1, false, true);
if (edge_ptr->InitEdgeCost() != SUCCESS) {
MS_LOG(EXCEPTION) << "Edge cost initialization failed";
}
target_cnode->operator_info()->AddPrevEdge(edge_ptr);
target_op_info->AddPrevEdge(edge_ptr);
tmp_identity_ptr->AddSuccEdge(edge_ptr);
entire_costgraph->AddEdge(tmp_identity_ptr, target_cnode->operator_info(), edge_ptr);
entire_costgraph->AddEdge(tmp_identity_ptr, target_op_info, edge_ptr);
MS_LOG(INFO) << "Successfully adding the edge between " << tmp_identity_ptr->name() << " and "
<< target_cnode->operator_info()->name();
<< target_op_info->name();
add_identity_edge = true;
}
if (new_identity && add_identity_edge) {
@ -861,20 +862,13 @@ bool FindReshape(const CNodePtr &cnode) {
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
return false;
}
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
if (!IsParallelCareNode(cnode) || (cnode->operator_info() == nullptr)) {
if (!IsParallelCareNode(cnode) || !cnode->HasUserData<OperatorInfo>()) {
return false;
}
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
MS_EXCEPTION_IF_NULL(prim);
OperatorInfoPtr operator_info = cnode->operator_info();
if (operator_info == nullptr) {
MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->ToString() << " OperatorInstance is nullptr";
}
if (prim->name() != RESHAPE) {
return false;
}
return true;
return (prim->name() == RESHAPE);
}
// find previous node, then obtain its strategy_cost_ vector to get its layout vector.
@ -890,8 +884,9 @@ bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_
if (!IsValueNode<Primitive>(cnode->input(0))) {
return false;
}
if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) {
*pre_operator_info = cnode->operator_info();
auto node_op_info = cnode->GetUserData<OperatorInfo>();
if (IsParallelCareNode(cnode) && (node_op_info != nullptr)) {
*pre_operator_info = node_op_info;
*out_index = 0;
return true;
}
@ -905,8 +900,9 @@ bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_
MS_LOG(EXCEPTION) << "tuple get item's second input is not a cnode";
}
CNodePtr pre_cnode = pre_node->cast<CNodePtr>();
if (IsParallelCareNode(pre_cnode) && (pre_cnode->operator_info() != nullptr)) {
*pre_operator_info = pre_cnode->operator_info();
auto pre_op_info = pre_cnode->GetUserData<OperatorInfo>();
if (IsParallelCareNode(pre_cnode) && (pre_op_info != nullptr)) {
*pre_operator_info = pre_op_info;
return true;
}
return false;
@ -945,14 +941,15 @@ bool FindNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator
if (node_prim->name() == DEPEND && node_pair.second != 1) {
continue;
}
if (IsParallelCareNode(use_apply) && (use_apply->operator_info() != nullptr)) {
auto op_info = use_apply->GetUserData<OperatorInfo>();
if (IsParallelCareNode(use_apply) && (op_info != nullptr)) {
MS_LOG(INFO) << "FindNextNodeStraCosts success prim " << node_prim->name();
*next_operator_info = use_apply->operator_info();
*next_operator_info = op_info;
*in_index = node_pair.second - 1;
return true;
}
MS_LOG(DEBUG) << "FindNextNodeStraCosts failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply)
<< " " << (use_apply->operator_info() != nullptr);
<< " " << (op_info != nullptr);
if (FindNextNodeStraCosts(use_apply, next_operator_info, in_index)) {
return true;
@ -973,8 +970,8 @@ void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
int32_t out_index = 0;
OperatorInfoPtr pre_operator_info;
std::vector<std::shared_ptr<StrategyWithCost>> pre_stra_costs;
auto operator_info = cnode->GetUserData<OperatorInfo>();
if (pre_node->isa<Parameter>()) {
OperatorInfoPtr operator_info = cnode->operator_info();
auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
reshape_info->SetCostForReshapeWithParameter();
pre_operator_info = reshape_info;
@ -995,7 +992,6 @@ void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
}
// set input_layout and output_layout for reshape.
// init reshape and set cost for each input_layout and output_layout.
OperatorInfoPtr operator_info = cnode->operator_info();
auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
reshape_info->set_pre_operator_name(pre_operator_info->name());
reshape_info->set_pre_operator_index(out_index);

@ -272,7 +272,7 @@ OperatorInfoPtr GetDistributeOperator(const CNodePtr &node) {
if (!IsParallelCareNode(node)) {
return nullptr;
}
OperatorInfoPtr distribute_operator = node->operator_info();
OperatorInfoPtr distribute_operator = node->GetUserData<OperatorInfo>();
if (distribute_operator == nullptr) {
MS_LOG(EXCEPTION) << "GetDistributeOperator:distribute_operator is nullptr";
}
@ -415,7 +415,7 @@ bool IsParallelCareNode(const CNodePtr &cnode) {
if (prim->name() == GET_NEXT) {
return true;
}
if ((prim->name() == CAST) && (cnode->operator_info() == nullptr)) {
if ((prim->name() == CAST) && !cnode->HasUserData<OperatorInfo>()) {
return false;
}
@ -452,7 +452,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_
if (node_prim->name() == DEPEND && node_pair.second != 1) {
continue;
}
if (IsParallelCareNode(use_cnode) && (use_cnode->operator_info() != nullptr)) {
if (IsParallelCareNode(use_cnode) && use_cnode->HasUserData<OperatorInfo>()) {
Redistribution(node_pair, distribute_operator, insert_node_new, node_pair.second, tensor_redistribution,
pre_node);
} else {
@ -465,7 +465,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_
void SplitTensor(const AnfNodePtr &node, const CNodePtr &next_node, int index) {
MS_EXCEPTION_IF_NULL(node);
MS_EXCEPTION_IF_NULL(next_node);
OperatorInfoPtr op_info = next_node->operator_info();
OperatorInfoPtr op_info = next_node->GetUserData<OperatorInfo>();
MS_EXCEPTION_IF_NULL(op_info);
// If the shape of tensor is [] or [1], no need to split it.
@ -590,7 +590,7 @@ void ReplaceOneOp(const Operator &replace_op, const CNodePtr &node) {
void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) {
// step1:get graph manager distribute_operator
OperatorInfoPtr distribute_operator = node->operator_info();
OperatorInfoPtr distribute_operator = node->GetUserData<OperatorInfo>();
if (distribute_operator == nullptr) {
MS_LOG(EXCEPTION) << "Failure:AddNode error since distribute_operator is nullptr";
}
@ -628,7 +628,7 @@ void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) {
(void)prim->SetAttrs(attrs);
}
if (index == replace_op.size() - 1) {
(void)replace_node->set_operator_info(node->operator_info());
replace_node->SetUserData<OperatorInfo>(node->GetUserData<OperatorInfo>());
}
replace_node->set_in_forward_flag(true);
replace_input[0]->set_scope(scope);
@ -708,7 +708,7 @@ LossNodeInfo GetLossNodeInfo(const AnfNodePtr &loss_node) {
auto pre_cnode = pre_node->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(pre_cnode);
auto pre_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
if (pre_prim->name() == CAST && pre_cnode->operator_info() == nullptr) {
if (pre_prim->name() == CAST && !pre_cnode->HasUserData<OperatorInfo>()) {
pre_node = pre_cnode->input(1);
}
@ -1204,7 +1204,7 @@ std::pair<AnfNodePtr, int> FindParallelCareNode(const AnfNodePtr &node) {
if (node_prim->name() == DEPEND && node_pair.second != 1) {
continue;
}
if (IsParallelCareNode(cnode) && cnode->operator_info() != nullptr) {
if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) {
return node_pair;
} else if (FindParallelCareNode(node_pair.first).first != nullptr) {
return FindParallelCareNode(node_pair.first);
@ -1254,7 +1254,7 @@ void SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, i
MS_LOG(DEBUG) << "SetParallelShape " << parameter->ToString() << " shape " << parameter->Shape()->ToString();
CNodePtr cnode = res.first->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(cnode);
OperatorInfoPtr distribute_operator = cnode->operator_info();
OperatorInfoPtr distribute_operator = cnode->GetUserData<OperatorInfo>();
if (distribute_operator == nullptr) {
MS_LOG(EXCEPTION) << "Failure:node " << cnode->ToString() << " 's OperatorInfoPtr is nullptr";
}
@ -1277,7 +1277,7 @@ void SetParallelShape(const AnfNodePtr &parameter, const std::pair<AnfNodePtr, i
TensorLayout tensor_layout = tensorinfo_in.tensor_layout();
ParameterPtr parameter_ptr = parameter->cast<ParameterPtr>();
MS_EXCEPTION_IF_NULL(parameter_ptr);
parameter_ptr->set_tensor_layout(std::make_shared<TensorLayout>(tensor_layout));
parameter_ptr->SetUserData<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
}
void CoverSliceShape(const FuncGraphPtr &root) {
@ -1365,7 +1365,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
if (found_be_cloned_parameter) {
// set the shape and tensor layout for cloned parameter
cloned_parameter->set_tensor_layout(cloned_from_parameter->tensor_layout());
cloned_parameter->SetUserData<TensorLayout>(cloned_from_parameter->GetUserData<TensorLayout>());
MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract());
MS_EXCEPTION_IF_NULL(cloned_from_node->abstract());
auto cloned_abstract = cloned_parameter_node->abstract()->Clone();
@ -1464,7 +1464,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
(*operator_).set_outputs_dtype(cnode->Type());
(*operator_).set_cnode(cnode);
if (prim->name() == RESHAPE) {
(void)cnode->set_operator_info(operator_);
cnode->SetUserData<OperatorInfo>(operator_);
continue;
}
// load strategy checkpoint
@ -1499,7 +1499,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
if (operator_->Init(strategyPtr) == FAILED) {
MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed";
}
(void)cnode->set_operator_info(operator_);
cnode->SetUserData<OperatorInfo>(operator_);
} else {
MS_LOG(EXCEPTION) << "ERROR:strategy_ptr is nullptr";
}
@ -1542,13 +1542,13 @@ std::shared_ptr<TensorLayout> FindNextLayout(const CNodePtr &cnode) {
if (node_prim->name() == DEPEND && node_pair.second != 1) {
continue;
}
if (IsParallelCareNode(use_apply) && (use_apply->operator_info() != nullptr)) {
if (IsParallelCareNode(use_apply) && use_apply->HasUserData<OperatorInfo>()) {
MS_LOG(INFO) << "FindNextLayout success prim " << node_prim->name();
auto layout = GetInputLayoutFromCNode(node_pair);
return std::make_shared<TensorLayout>(layout);
}
MS_LOG(DEBUG) << "FindNextLayout failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply)
<< " " << (use_apply->operator_info() != nullptr);
<< " " << use_apply->HasUserData<OperatorInfo>();
auto layout_ptr = FindNextLayout(use_apply);
if (layout_ptr) {
@ -1580,7 +1580,7 @@ std::shared_ptr<TensorLayout> FindPrevParallelCareNodeLayout(const AnfNodePtr &n
if (!IsValueNode<Primitive>(cnode->input(0))) {
return nullptr;
}
if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) {
if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) {
auto layout_ptr = GetOutputLayoutFromCNode(cnode, output_index);
if (!layout_ptr) {
MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed";
@ -1624,7 +1624,7 @@ std::shared_ptr<TensorLayout> FindPrevLayout(const AnfNodePtr &node) {
if (!IsValueNode<Primitive>(cnode->input(0))) {
return nullptr;
}
if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) {
if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) {
auto layout_ptr = GetOutputLayoutFromCNode(cnode, 0);
if (!layout_ptr) {
MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed";
@ -1664,12 +1664,12 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes) {
continue;
}
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
if (!IsParallelCareNode(cnode) || (cnode->operator_info() == nullptr)) {
if (!IsParallelCareNode(cnode) || !cnode->HasUserData<OperatorInfo>()) {
continue;
}
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
MS_EXCEPTION_IF_NULL(prim);
OperatorInfoPtr operator_info = cnode->operator_info();
OperatorInfoPtr operator_info = cnode->GetUserData<OperatorInfo>();
if (operator_info == nullptr) {
MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->ToString() << " OperatorInstance is nullptr";
}
@ -1714,7 +1714,7 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
// return -> cast
if (current_prim->name() == CAST && pre_cnode->operator_info() == nullptr) {
if (current_prim->name() == CAST && !pre_cnode->HasUserData<OperatorInfo>()) {
pre_cnode = pre_cnode->input(1)->cast<CNodePtr>();
MS_EXCEPTION_IF_NULL(pre_cnode);
current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
@ -1771,7 +1771,7 @@ TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &loss_cnode) {
return ret;
}
OperatorInfoPtr operator_info = loss_cnode->operator_info();
OperatorInfoPtr operator_info = loss_cnode->GetUserData<OperatorInfo>();
MS_EXCEPTION_IF_NULL(operator_info);
TensorInfo loss_grad_tensor_info;
size_t op_output_size = operator_info->outputs_tensor_info().size();
@ -1809,7 +1809,7 @@ void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_lay
if (sens_tensor_node->isa<Parameter>()) {
auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>();
MS_LOG(DEBUG) << "loss layout " << loss_grad_layout.ToString();
sens_tensor_param->set_tensor_layout(std::make_shared<TensorLayout>(loss_grad_layout));
sens_tensor_param->SetUserData<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout));
}
MS_LOG(INFO) << "The shape of sens is " << ShapeToString(sens_shape) << ", no need to split sens";
return;
@ -1834,7 +1834,7 @@ void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_lay
cloned_abstract->set_shape(parallel_shape);
sens_tensor_node->set_abstract(cloned_abstract);
auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>();
sens_tensor_param->set_tensor_layout(std::make_shared<TensorLayout>(loss_grad_layout));
sens_tensor_param->SetUserData<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout));
return;
}
MS_LOG(EXCEPTION) << "The type of sens node is not Tensor or Parameter, it is unsupported now.";
@ -2125,7 +2125,7 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) {
}
PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
MS_EXCEPTION_IF_NULL(prim);
OperatorInfoPtr operator_info = cnode->operator_info();
OperatorInfoPtr operator_info = cnode->GetUserData<OperatorInfo>();
if (operator_info) {
if (operator_info->name().find(RESHAPEINFO) != std::string::npos) {
continue;

@ -83,6 +83,9 @@ class TensorLayout {
TensorLayout SqueezeShape() const;
// Key for user data.
constexpr static char key[] = "TLayout";
private:
std::shared_ptr<TensorLayout> ExpandTensorShapeWithoutExtendDeviceArrangement(
const Arrangement &expanded_shape) const;

@ -0,0 +1,160 @@
/**
* Copyright 2019-2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPERATOR_OPS_H_
#define MINDSPORE_CORE_OPERATOR_OPS_H_
#include <iostream>
#include <string>
#include <memory>
#include "ir/anf.h"
#include "ir/primitive.h"
namespace mindspore {
namespace prim {
// Maths
inline const PrimitivePtr kPrimTensorAdd = std::make_shared<Primitive>("TensorAdd");
inline const PrimitivePtr kPrimMatMul = std::make_shared<Primitive>("MatMul");
inline const PrimitivePtr kPrimBatchMatMul = std::make_shared<Primitive>("BatchMatMul");
inline const PrimitivePtr kPrimMaximumGrad = std::make_shared<Primitive>("MaximumGrad");
inline const PrimitivePtr kPrimMinimumGrad = std::make_shared<Primitive>("MinimumGrad");
inline const PrimitivePtr kPrimReduceMean = std::make_shared<Primitive>("ReduceMean");
inline const PrimitivePtr kPrimReduceSum = std::make_shared<Primitive>("ReduceSum");
inline const PrimitivePtr kPrimReduceAll = std::make_shared<Primitive>("ReduceAll");
inline const PrimitivePtr kPrimReduceMax = std::make_shared<Primitive>("ReduceMax");
inline const PrimitivePtr kPrimReduceMin = std::make_shared<Primitive>("ReduceMin");
inline const PrimitivePtr kPrimNeg = std::make_shared<Primitive>("Neg");
inline const PrimitivePtr kPrimSub = std::make_shared<Primitive>("Sub");
inline const PrimitivePtr kPrimMul = std::make_shared<Primitive>("Mul");
inline const PrimitivePtr kPrimMinimum = std::make_shared<Primitive>("Minimum");
inline const PrimitivePtr kPrimMaximum = std::make_shared<Primitive>("Maximum");
inline const PrimitivePtr kPrimSquare = std::make_shared<Primitive>("Square");
inline const PrimitivePtr kPrimCumSum = std::make_shared<Primitive>("CumSum");
inline const PrimitivePtr kPrimCumProd = std::make_shared<Primitive>("CumProd");
inline const PrimitivePtr kPrimSubscalar = std::make_shared<Primitive>("Subscalar");
inline const PrimitivePtr kPrimInplaceAdd = std::make_shared<Primitive>("InplaceAdd");
inline const PrimitivePtr kPrimInplaceSub = std::make_shared<Primitive>("InplaceSub");
inline const PrimitivePtr kPrimPow = std::make_shared<Primitive>("Pow");
inline const PrimitivePtr kPrimRealDiv = std::make_shared<Primitive>("RealDiv");
inline const PrimitivePtr kPrimSqrt = std::make_shared<Primitive>("Sqrt");
inline const PrimitivePtr kPrimReciprocal = std::make_shared<Primitive>("Reciprocal");
inline const PrimitivePtr kPrimExpandDims = std::make_shared<Primitive>("ExpandDims");
// Statements
inline const PrimitivePtr kPrimReturn = std::make_shared<Primitive>("return");
inline const PrimitivePtr kPrimSwitch = std::make_shared<Primitive>("switch");
inline const PrimitivePtr kPrimSwitchLayer = std::make_shared<Primitive>("switch_layer");
inline const PrimitivePtr kPrimAssign = std::make_shared<Primitive>("Assign");
inline const PrimitivePtr kPrimAssignAdd = std::make_shared<Primitive>("AssignAdd");
inline const PrimitivePtr kPrimAssignSub = std::make_shared<Primitive>("AssignSub");
inline const PrimitivePtr kPrimSelect = std::make_shared<Primitive>("Select");
inline const PrimitivePtr kPrimCall = std::make_shared<Primitive>("call");
// Structures
inline const PrimitivePtr kPrimStringEqual = std::make_shared<Primitive>("string_equal");
inline const PrimitivePtr kPrimStringConcat = std::make_shared<Primitive>("string_concat");
inline const PrimitivePtr kPrimMakeTuple = std::make_shared<Primitive>("make_tuple");
inline const PrimitivePtr kPrimMakeDict = std::make_shared<Primitive>("make_dict");
inline const PrimitivePtr kPrimMakeList = std::make_shared<Primitive>("make_list");
inline const PrimitivePtr kPrimMakeKeywordArg = std::make_shared<Primitive>("make_keyword_arg");
inline const PrimitivePtr kPrimMakeSlice = std::make_shared<Primitive>("make_slice");
inline const PrimitivePtr kPrimMakeRecord = std::make_shared<Primitive>("make_record");
inline const PrimitivePtr kPrimTupleGetItem = std::make_shared<Primitive>("tuple_getitem");
inline const PrimitivePtr kPrimListGetItem = std::make_shared<Primitive>("list_getitem");
inline const PrimitivePtr kPrimArrayGetItem = std::make_shared<Primitive>("array_getitem");
inline const PrimitivePtr kPrimTupleSetItem = std::make_shared<Primitive>("tuple_setitem");
inline const PrimitivePtr kPrimListSetItem = std::make_shared<Primitive>("list_setitem");
inline const PrimitivePtr kPrimArraySetItem = std::make_shared<Primitive>("array_setitem");
inline const PrimitivePtr kPrimDictGetItem = std::make_shared<Primitive>("dict_getitem");
inline const PrimitivePtr kPrimDictSetItem = std::make_shared<Primitive>("dict_setitem");
inline const PrimitivePtr kPrimListAppend = std::make_shared<Primitive>("list_append");
inline const PrimitivePtr kPrimGetAttr = std::make_shared<Primitive>("getattr");
inline const PrimitivePtr kPrimTupleLen = std::make_shared<Primitive>("tuple_len");
inline const PrimitivePtr kPrimDictLen = std::make_shared<Primitive>("dict_len");
inline const PrimitivePtr kPrimListLen = std::make_shared<Primitive>("list_len");
inline const PrimitivePtr kPrimArrayLen = std::make_shared<Primitive>("array_len");
inline const PrimitivePtr kPrimListMap = std::make_shared<Primitive>("list_map");
inline const PrimitivePtr kPrimListReduce = std::make_shared<Primitive>("list_reduce");
inline const PrimitivePtr kPrimTupleReversed = std::make_shared<Primitive>("tuple_reversed");
inline const PrimitivePtr kPrimTileShape = std::make_shared<Primitive>("tile_shape");
inline const PrimitivePtr kPrimReducedShape = std::make_shared<Primitive>("reduced_shape");
inline const PrimitivePtr kPrimTupleDiv = std::make_shared<Primitive>("tuple_div");
inline const PrimitivePtr kPrimTupleToArray = std::make_shared<Primitive>("tuple_to_array");
inline const PrimitivePtr kPrimShapeMul = std::make_shared<Primitive>("shape_mul");
inline const PrimitivePtr kPrimGenerateShapeIndex = std::make_shared<Primitive>("generate_shape_index");
inline const PrimitivePtr kPrimGenerateInverseIndex = std::make_shared<Primitive>("generate_inverse_index");
inline const PrimitivePtr kPrimTupleEqual = std::make_shared<Primitive>("tuple_equal");
inline const PrimitivePtr kPrimListEqual = std::make_shared<Primitive>("list_equal");
inline const PrimitivePtr kPrimMakeRange = std::make_shared<Primitive>("make_range");
inline const PrimitivePtr kPrimStopGradient = std::make_shared<Primitive>("stop_gradient");
inline const PrimitivePtr kPrimExtractKeywordArg = std::make_shared<Primitive>("extract_keyword_arg");
// Debug ops
inline const PrimitivePtr kPrimScalarSummary = std::make_shared<Primitive>("ScalarSummary");
inline const PrimitivePtr kPrimImageSummary = std::make_shared<Primitive>("ImageSummary");
inline const PrimitivePtr kPrimTensorSummary = std::make_shared<Primitive>("TensorSummary");
inline const PrimitivePtr kPrimHistogramSummary = std::make_shared<Primitive>("HistogramSummary");
inline const PrimitivePtr kPrimDebug = std::make_shared<Primitive>("Debug");
// Other miscellaneous
inline const PrimitivePtr kPrimJ = std::make_shared<Primitive>("J");
inline const PrimitivePtr kPrimDepend = std::make_shared<Primitive>("Depend");
inline const PrimitivePtr kPrimPartial = std::make_shared<Primitive>("Partial");
inline const PrimitivePtr kPrimIdentity = std::make_shared<Primitive>("identity");
inline const PrimitivePtr kPrimEnvSetItem = std::make_shared<Primitive>("env_setitem");
inline const PrimitivePtr kPrimEnvGetItem = std::make_shared<Primitive>("env_getitem");
inline const PrimitivePtr kPrimEnvAdd = std::make_shared<Primitive>("env_add");
inline const PrimitivePtr kPrimMakeRefKey = std::make_shared<Primitive>("MakeRefKey");
inline const PrimitivePtr kPrimGetRefKey = std::make_shared<Primitive>("get_ref_key");
inline const PrimitivePtr kPrimGetRefValue = std::make_shared<Primitive>("get_ref_value");
inline const PrimitivePtr kPrimGetRefOrigin = std::make_shared<Primitive>("get_ref_origin");
inline const PrimitivePtr kPrimInsertGradientOf = std::make_shared<Primitive>("InsertGradientOf");
inline const PrimitivePtr kPrimHookBackward = std::make_shared<Primitive>("HookBackward");
inline const PrimitivePtr kPrimPrintShapeType = std::make_shared<Primitive>("PrintShapeType");
inline const PrimitivePtr kPrimSameTypeShape = std::make_shared<Primitive>("SameTypeShape");
inline const PrimitivePtr kPrimCheckBprop = std::make_shared<Primitive>("CheckBprop");
inline const PrimitivePtr kPrimPrint = std::make_shared<Primitive>("Print");
inline const PrimitivePtr kPrimMakeRef = std::make_shared<Primitive>("make_ref");
inline const PrimitivePtr kPrimStateSetItem = std::make_shared<Primitive>("state_setitem");
inline const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared<Primitive>("BroadcastGradientArgs");
inline const PrimitivePtr kPrimControlDepend = std::make_shared<Primitive>("ControlDepend");
inline const PrimitivePtr kPrimIs_ = std::make_shared<Primitive>("is_");
inline const PrimitivePtr kPrimIsNot = std::make_shared<Primitive>("is_not");
inline const PrimitivePtr kPrimInDict = std::make_shared<Primitive>("in_dict");
inline const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict");
inline const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared<Primitive>("mixed_precision_cast");
inline const PrimitivePtr kPrimIsConsant = std::make_shared<Primitive>("is_constant");
inline const PrimitivePtr kPrimEquivFormat = std::make_shared<Primitive>("EquivFormat");
class DoSignaturePrimitive : public Primitive {
public:
explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function)
: Primitive("S-Prim-" + name), function_(function) {}
~DoSignaturePrimitive() override = default;
MS_DECLARE_PARENT(DoSignaturePrimitive, Primitive)
const ValuePtr function() const { return function_; }
private:
ValuePtr function_;
};
using DoSignaturePrimitivePtr = std::shared_ptr<DoSignaturePrimitive>;
} // namespace prim
} // namespace mindspore
#endif // MINDSPORE_CORE_OPERATOR_OPS_H_
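
The new base/core_ops.h keeps these inline primitive constants in a core header so that IR and optimizer code no longer needs frontend/operator/ops.h; the include swaps in the hunks further down are the mechanical side of that. As a hedged sketch of how such a constant is typically consumed (assuming GetValueNode and the node types are available from the core IR headers, as the call sites in this diff suggest):

#include "base/core_ops.h"
#include "ir/anf.h"

// Returns true when the cnode's first input is the tuple_getitem primitive.
bool IsTupleGetItemNode(const mindspore::CNodePtr &cnode) {
  auto prim = mindspore::GetValueNode<mindspore::PrimitivePtr>(cnode->input(0));
  return prim != nullptr && prim->name() == mindspore::prim::kPrimTupleGetItem->name();
}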

@ -0,0 +1,52 @@
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_USER_DATA_H_
#define MINDSPORE_CORE_USER_DATA_H_
#include <string>
#include <memory>
#include <map>
namespace mindspore {
class UserData {
public:
template <typename T>
void set(const std::string &key, const std::shared_ptr<T> &value) {
if (value == nullptr) {
data_.erase(key);
} else {
data_.insert_or_assign(key, value);
}
}
template <typename T>
std::shared_ptr<T> get(const std::string &key) const {
auto iter = data_.find(key);
if (iter == data_.end()) {
return nullptr;
}
return std::static_pointer_cast<T>(iter->second);
}
bool has(const std::string &key) const { return data_.find(key) != data_.end(); }
private:
std::map<std::string, std::shared_ptr<void>> data_;
};
} // namespace mindspore
#endif // MINDSPORE_CORE_USER_DATA_H_
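
A brief usage sketch of the class above (ParallelHint is a hypothetical payload type): values are stored type-erased as shared_ptr<void> and cast back on retrieval, so each key must stay consistently paired with one payload type, and setting a nullptr value erases the entry.

#include <memory>
#include "base/user_data.h"

struct ParallelHint {  // hypothetical payload type
  int stage = 0;
};

void Demo() {
  mindspore::UserData user_data;
  user_data.set<ParallelHint>("hint", std::make_shared<ParallelHint>());
  if (user_data.has("hint")) {
    user_data.get<ParallelHint>("hint")->stage = 1;  // static cast back from shared_ptr<void>
  }
  user_data.set<ParallelHint>("hint", nullptr);  // a nullptr value erases the entry
}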

@ -26,7 +26,6 @@
#include "ir/func_graph.h"
#include "ir/primitive.h"
#include "utils/context/ms_context.h"
#include "frontend/operator/ops.h"
namespace mindspore {
// namespace to support intermediate representation definition

@ -27,6 +27,7 @@
#include <utility>
#include "base/base.h"
#include "base/user_data.h"
#include "ir/kernel_info_dev.h"
#include "ir/scope.h"
#include "debug/info.h"
@ -41,12 +42,6 @@
// ANode: Atomic Node
// CNode: Complex Node
namespace mindspore {
namespace parallel {
class TensorLayout;
class OperatorInfo;
} // namespace parallel
using OperatorInfoPtr = std::shared_ptr<parallel::OperatorInfo>;
namespace abstract {
class BaseShape;
class AbstractBase;
@ -157,6 +152,31 @@ class AnfNode : public Base {
}
size_t seen_{0};
template <typename T>
void SetUserData(const std::string &key, const std::shared_ptr<T> &value) {
user_data_.set<T>(key, value);
}
template <typename T>
void SetUserData(const std::shared_ptr<T> &value) {
user_data_.set<T>(T::key, value);
}
template <typename T>
std::shared_ptr<T> GetUserData(const std::string &key) const {
return user_data_.get<T>(key);
}
template <typename T>
std::shared_ptr<T> GetUserData() const {
return user_data_.get<T>(T::key);
}
bool HasUserData(const std::string &key) const { return user_data_.has(key); }
template <typename T>
bool HasUserData() const { return user_data_.has(T::key); }
protected:
// Hold a weak ref to Graph as Graph also hold ref to AnfNode.
// Otherwise, func_graph_ and AnfNode will make a reference cycle.
@ -170,6 +190,7 @@ class AnfNode : public Base {
std::hash<const AnfNode *> hash_;
ScopePtr scope_;
KernelInfoDevicePtr kernel_info_;
UserData user_data_;
};
// CNode represents the complex node with a set of arguments.
@ -212,9 +233,6 @@ class CNode : public AnfNode {
std::string DebugString(int recursive_level = 1) const override;
std::string DebugString(bool recursive) const override { return DebugString(recursive ? 1 : 0); }
OperatorInfoPtr set_operator_info(const OperatorInfoPtr &operator_info);
OperatorInfoPtr operator_info() { return operator_info_; }
void set_in_forward_flag(bool flag) { in_forward_flag_ = flag; }
bool in_forward_flag() const { return in_forward_flag_; }
@ -224,7 +242,6 @@ class CNode : public AnfNode {
std::vector<AnfNodePtr> inputs_;
VarPtr func_graph_as_var_;
bool stop_gradient_;
OperatorInfoPtr operator_info_ = nullptr;
bool in_forward_flag_ = false;
};
@ -244,7 +261,7 @@ class ANode : public AnfNode {
class Parameter : public ANode {
public:
explicit Parameter(const FuncGraphPtr &func_graph)
: ANode(func_graph), name_(""), has_default_(false), default_param_(nullptr), tensor_layout_(nullptr) {}
: ANode(func_graph), name_(""), has_default_(false), default_param_(nullptr) {}
~Parameter() override = default;
MS_DECLARE_PARENT(Parameter, ANode);
@ -261,11 +278,6 @@ class Parameter : public ANode {
}
ParamValuePtr default_param() const { return default_param_; }
std::shared_ptr<parallel::TensorLayout> tensor_layout() const { return tensor_layout_; }
void set_tensor_layout(const std::shared_ptr<parallel::TensorLayout> &tensor_layout) {
tensor_layout_ = tensor_layout;
}
bool operator==(const AnfNode &other) const override {
if (!other.isa<Parameter>()) {
return false;
@ -281,7 +293,6 @@ class Parameter : public ANode {
std::string name_;
bool has_default_;
ParamValuePtr default_param_;
std::shared_ptr<parallel::TensorLayout> tensor_layout_;
};
using ParameterPtr = std::shared_ptr<Parameter>;
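
With the accessors above, any module can hang its own data on a node without ir/anf.h declaring the type, which is what allows the parallel::OperatorInfo and parallel::TensorLayout declarations to be dropped from this header. A hedged sketch with a hypothetical annotation type (FusionTag is illustrative, not part of this change):

#include <memory>
#include "ir/anf.h"

struct FusionTag {  // hypothetical module-local annotation
  constexpr static char key[] = "FusionTag";
  bool fused = false;
};

void Annotate(const mindspore::AnfNodePtr &node) {
  node->SetUserData<FusionTag>(std::make_shared<FusionTag>());
  if (node->HasUserData<FusionTag>()) {
    node->GetUserData<FusionTag>()->fused = true;  // same shared pointer, looked up by T::key
  }
}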

@ -23,8 +23,7 @@
#include "ir/visitor.h"
#include "ir/func_graph.h"
#include "frontend/operator/ops.h"
#include "frontend/parallel/ops_info/ops_utils.h"
#include "base/core_ops.h"
#include "debug/label.h"
namespace mindspore {
@ -37,18 +36,6 @@ std::string AnfNode::ToString() const {
return mindspore::label_manage::Label(const_cast<AnfNode *>(this)->shared_from_base<AnfNode>()->debug_info());
}
OperatorInfoPtr CNode::set_operator_info(const OperatorInfoPtr &operator_info) {
if (operator_info_ != nullptr) {
MS_LOG(WARNING) << "The CNode: " << ToString() << " has already been set OperatorInfo: " << operator_info_->name()
<< ", using the new one: " << operator_info->name();
auto old_ptr = operator_info_;
operator_info_ = operator_info;
return old_ptr;
}
operator_info_ = operator_info;
return nullptr;
}
std::string CNode::fullname_with_scope() {
// if full name is set, return its name immediately
if (!fullname_with_scope_.empty()) {

@ -24,7 +24,6 @@
#include "debug/trace.h"
#include "ir/manager.h"
#include "frontend/operator/ops.h"
#include "utils/ordered_set.h"
#include "utils/convert_utils_base.h"

@ -20,7 +20,7 @@
#include "ir/manager.h"
#include "ir/param_value.h"
#include "frontend/operator/ops.h"
#include "base/core_ops.h"
#include "utils/convert_utils_base.h"
#include "utils/log_adapter.h"
#include "utils/profile.h"

@ -22,7 +22,7 @@
#include "ir/manager.h"
#include "ir/func_graph_cloner.h"
#include "frontend/operator/ops.h"
#include "base/core_ops.h"
#include "utils/ordered_set.h"
#include "abstract/abstract_value.h"
#include "debug/anf_ir_dump.h"

@ -26,7 +26,7 @@
#include "ir/func_graph.h"
#include "utils/profile.h"
#include "utils/convert_utils_base.h"
#include "frontend/operator/ops.h"
#include "base/core_ops.h"
namespace mindspore {

@ -17,10 +17,8 @@
*/
#include "ir/meta_func_graph.h"
#include "pipeline/jit/static_analysis/static_analysis.h"
#include "pipeline/jit/static_analysis/abstract_function.h"
#include "base/core_ops.h"
#include "utils/context/ms_context.h"
#include "frontend/operator/ops.h"
// namespace to support intermediate representation definition
namespace mindspore {

@ -22,9 +22,9 @@
#include <tuple>
#include <vector>
#include "frontend/operator/ops.h"
#include "frontend/optimizer/optimizer.h"
#include "ir/anf.h"
#include "ir/optimizer_caller.h"
#include "base/core_ops.h"
namespace mindspore {
///

@ -25,7 +25,6 @@
#include "ir/dtype/type.h"
#include "abstract/abstract_value.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "utils/base_ref_extends.h"
namespace mindspore {

@ -18,7 +18,6 @@
#include <mutex>
#include <utility>
#include "ir/signature.h"
#include "frontend/operator/ops.h"
#include "./common.h"
#include "pipeline/jit/parse/python_adapter.h"
#include "pipeline/jit/parse/data_converter.h"

@ -28,7 +28,6 @@
#include <type_traits>
#include <typeinfo>
#include "runtime/device/device_address.h"
#include "abstract/abstract_value.h"
namespace mindspore {

Some files were not shown because too many files have changed in this diff.
