diff --git a/mindspore/ccsrc/debug/anf_ir_dump.cc b/mindspore/ccsrc/debug/anf_ir_dump.cc index 42d372cefb..563000cd1f 100644 --- a/mindspore/ccsrc/debug/anf_ir_dump.cc +++ b/mindspore/ccsrc/debug/anf_ir_dump.cc @@ -27,6 +27,7 @@ #include "runtime/device/kernel_info.h" #include "utils/graph_utils.h" #include "backend/session/anf_runtime_algorithm.h" +#include "frontend/parallel/ops_info/operator_info.h" namespace mindspore { const std::string ToShortString(const TypeId &typeId) { @@ -266,7 +267,7 @@ void DumpParallelInfo(const CNodePtr &node, const std::shared_ptroperator_info(); + auto operator_info = node->GetUserData(); if (operator_info == nullptr) { return; } diff --git a/mindspore/ccsrc/debug/draw.cc b/mindspore/ccsrc/debug/draw.cc index ff8132fb28..6e42277d4e 100644 --- a/mindspore/ccsrc/debug/draw.cc +++ b/mindspore/ccsrc/debug/draw.cc @@ -437,7 +437,7 @@ static void DrawParallelInfo(Graphviz *const graph_obj, const CNodePtr &node) { if (graph_obj == nullptr || node == nullptr) { return; } - auto distributed_operation_info = node->operator_info(); + auto distributed_operation_info = node->GetUserData(); if (distributed_operation_info != nullptr) { auto strategyPtr = distributed_operation_info->strategy(); if (strategyPtr != nullptr) { diff --git a/mindspore/ccsrc/frontend/operator/ops.cc b/mindspore/ccsrc/frontend/operator/ops.cc deleted file mode 100755 index bf3d55678e..0000000000 --- a/mindspore/ccsrc/frontend/operator/ops.cc +++ /dev/null @@ -1,293 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "frontend/operator/ops.h" -#include -#include - -namespace mindspore { -// namespace to support primitive operators -namespace prim { -// Arithmetic -const PrimitivePtr kPrimScalarAdd = std::make_shared("scalar_add"); -const PrimitivePtr kPrimScalarSub = std::make_shared("scalar_sub"); -const PrimitivePtr kPrimScalarMul = std::make_shared("scalar_mul"); -const PrimitivePtr kPrimScalarDiv = std::make_shared("scalar_div"); -const PrimitivePtr kPrimScalarFloordiv = std::make_shared("scalar_floordiv"); -const PrimitivePtr kPrimScalarMod = std::make_shared("scalar_mod"); -const PrimitivePtr kPrimScalarPow = std::make_shared("scalar_pow"); -const PrimitivePtr kPrimScalarTrunc = std::make_shared("scalar_trunc"); -const PrimitivePtr kPrimScalarFloor = std::make_shared("scalar_floor"); -const PrimitivePtr kPrimScalarUadd = std::make_shared("scalar_uadd"); -const PrimitivePtr kPrimScalarUsub = std::make_shared("scalar_usub"); -const PrimitivePtr kPrimScalarExp = std::make_shared("scalar_exp"); -const PrimitivePtr kPrimScalarLog = std::make_shared("scalar_log"); -const PrimitivePtr kPrimScalarSin = std::make_shared("scalar_sin"); -const PrimitivePtr kPrimScalarCos = std::make_shared("scalar_cos"); -const PrimitivePtr kPrimScalarTan = std::make_shared("scalar_tan"); - -// Comparisons -const PrimitivePtr kPrimScalarEq = std::make_shared("scalar_eq"); -const PrimitivePtr kPrimScalarLt = std::make_shared("scalar_lt"); -const PrimitivePtr kPrimScalarGt = std::make_shared("scalar_gt"); -const PrimitivePtr kPrimScalarNe = std::make_shared("scalar_ne"); -const PrimitivePtr kPrimScalarLe = std::make_shared("scalar_le"); -const PrimitivePtr kPrimScalarGe = std::make_shared("scalar_ge"); -const PrimitivePtr kPrimBoolNot = std::make_shared("bool_not"); -const PrimitivePtr kPrimBoolAnd = std::make_shared("bool_and"); -const PrimitivePtr kPrimBoolOr = std::make_shared("bool_or"); -const PrimitivePtr kPrimBoolEq = std::make_shared("bool_eq"); -const PrimitivePtr kPrimGreater = std::make_shared("Greater"); -const PrimitivePtr kPrimGreaterEqual = std::make_shared("GreaterEqual"); -const PrimitivePtr kPrimLess = std::make_shared("Less"); -const PrimitivePtr kPrimLessEqual = std::make_shared("LessEqual"); -const PrimitivePtr kPrimEqual = std::make_shared("Equal"); -const PrimitivePtr kPrimNotEqual = std::make_shared("NotEqual"); - -// Type introspection -const PrimitivePtr kPrimTypeOf = std::make_shared("typeof"); -const PrimitivePtr kPrimHasType = std::make_shared("hastype"); - -// Statements -const PrimitivePtr kPrimSwitch = std::make_shared("switch"); -const PrimitivePtr kPrimSwitchLayer = std::make_shared("switch_layer"); -const PrimitivePtr kPrimReturn = std::make_shared("return"); -const PrimitivePtr kPrimAssign = std::make_shared("Assign"); -const PrimitivePtr kPrimAssignAdd = std::make_shared("AssignAdd"); -const PrimitivePtr kPrimAssignSub = std::make_shared("AssignSub"); -const PrimitivePtr kPrimSelect = std::make_shared("Select"); -const PrimitivePtr kPrimCall = std::make_shared("call"); - -const PrimitivePtr kPrimDistribute = std::make_shared("distribute"); -const PrimitivePtr kPrimDot = std::make_shared("dot"); -const PrimitivePtr kPrimIm2Col = std::make_shared("im2col"); -const PrimitivePtr kPrimCol2Im = std::make_shared("col2im"); -const PrimitivePtr kPrimIm2ColV1 = std::make_shared("im2col_v1"); -const PrimitivePtr kPrimCol2ImV1 = std::make_shared("col2im_v1"); - -const PrimitivePtr kPrimResolve = std::make_shared("resolve"); -const PrimitivePtr kPrimEmbed = std::make_shared("embed"); -const PrimitivePtr kPrimRefToEmbed = std::make_shared("RefToEmbed"); -const PrimitivePtr kPrimCreateInstance = std::make_shared("create_instance"); - -const PrimitivePtr kPrimLabelGoto = std::make_shared("LabelGoto"); -const PrimitivePtr kPrimLabelSwitch = std::make_shared("LabelSwitch"); -const PrimitivePtr kPrimLabelSet = std::make_shared("LabelSet"); - -// Structure -const PrimitivePtr kPrimStringEqual = std::make_shared("string_equal"); -const PrimitivePtr kPrimStringConcat = std::make_shared("string_concat"); -const PrimitivePtr kPrimMakeTuple = std::make_shared("make_tuple"); -const PrimitivePtr kPrimMakeList = std::make_shared("make_list"); -const PrimitivePtr kPrimMakeDict = std::make_shared("make_dict"); -const PrimitivePtr kPrimMakeKeywordArg = std::make_shared("make_keyword_arg"); -const PrimitivePtr kPrimExtractKeywordArg = std::make_shared("extract_keyword_arg"); -const PrimitivePtr kPrimMakeSlice = std::make_shared("make_slice"); -const PrimitivePtr kPrimMakeRecord = std::make_shared("make_record"); -const PrimitivePtr kPrimTupleGetItem = std::make_shared("tuple_getitem"); -const PrimitivePtr kPrimListGetItem = std::make_shared("list_getitem"); -const PrimitivePtr kPrimArrayGetItem = std::make_shared("array_getitem"); -const PrimitivePtr kPrimTupleSetItem = std::make_shared("tuple_setitem"); -const PrimitivePtr kPrimListSetItem = std::make_shared("list_setitem"); -const PrimitivePtr kPrimArraySetItem = std::make_shared("array_setitem"); -const PrimitivePtr kPrimDictGetItem = std::make_shared("dict_getitem"); -const PrimitivePtr kPrimDictSetItem = std::make_shared("dict_setitem"); -const PrimitivePtr kPrimListAppend = std::make_shared("list_append"); -const PrimitivePtr kPrimGetAttr = std::make_shared("getattr"); -const PrimitivePtr kPrimTupleLen = std::make_shared("tuple_len"); -const PrimitivePtr kPrimDictLen = std::make_shared("dict_len"); -const PrimitivePtr kPrimListLen = std::make_shared("list_len"); -const PrimitivePtr kPrimArrayLen = std::make_shared("array_len"); -const PrimitivePtr kPrimListMap = std::make_shared("list_map"); -const PrimitivePtr kPrimListReduce = std::make_shared("list_reduce"); -const PrimitivePtr kPrimTupleReversed = std::make_shared("tuple_reversed"); - -const PrimitivePtr kPrimTileShape = std::make_shared("tile_shape"); -const PrimitivePtr kPrimReducedShape = std::make_shared("reduced_shape"); -const PrimitivePtr kPrimTupleDiv = std::make_shared("tuple_div"); -const PrimitivePtr kPrimTupleToArray = std::make_shared("tuple_to_array"); -const PrimitivePtr kPrimShapeMul = std::make_shared("shape_mul"); -const PrimitivePtr kPrimGenerateShapeIndex = std::make_shared("generate_shape_index"); -const PrimitivePtr kPrimGenerateInverseIndex = std::make_shared("generate_inverse_index"); -const PrimitivePtr kPrimTupleEqual = std::make_shared("tuple_equal"); -const PrimitivePtr kPrimListEqual = std::make_shared("list_equal"); -const PrimitivePtr kPrimMakeRange = std::make_shared("make_range"); -const PrimitivePtr kPrimStopGradient = std::make_shared("stop_gradient"); - -// Arrays -const PrimitivePtr kPrimScalarToArray = std::make_shared("scalar_to_array"); -const PrimitivePtr kPrimArrayToScalar = std::make_shared("array_to_scalar"); -const PrimitivePtr kPrimBroadcastShape = std::make_shared("broadcast_shape"); -const PrimitivePtr kPrimArrayMap = std::make_shared("array_map"); -const PrimitivePtr kPrimArrayReduce = std::make_shared("array_reduce"); -const PrimitivePtr kPrimShape = std::make_shared("Shape"); -const PrimitivePtr kPrimCast = std::make_shared("Cast"); -const PrimitivePtr kPrimConcat = std::make_shared("Concat"); -const PrimitivePtr kPrimSqueeze = std::make_shared("Squeeze"); -const PrimitivePtr kPrimTranspose = std::make_shared("Transpose"); -const PrimitivePtr kPrimGatherV2 = std::make_shared("GatherV2"); -const PrimitivePtr kPrimEmbeddingLookup = std::make_shared("EmbeddingLookup"); -const PrimitivePtr kPrimEmbeddingLookupCommGrad = std::make_shared("EmbeddingLookupCommGrad"); -const PrimitivePtr kPrimSize = std::make_shared("Size"); -const PrimitivePtr kPrimArgMax = std::make_shared("Argmax"); -const PrimitivePtr kPrimPack = std::make_shared("Pack"); -const PrimitivePtr kPrimUnsortedSegmentSum = std::make_shared("UnsortedSegmentSum"); -const PrimitivePtr kPrimUnsortedSegmentMin = std::make_shared("UnsortedSegmentMin"); -const PrimitivePtr kPrimConcatOffset = std::make_shared("ConcatOffset"); -const PrimitivePtr kPrimReshape = std::make_shared("Reshape"); -const PrimitivePtr kPrimTile = std::make_shared("Tile"); -const PrimitivePtr kPrimAddN = std::make_shared("AddN"); -const PrimitivePtr KPrimTransData = std::make_shared("TransData"); -const PrimitivePtr kPrimNMSWithMask = std::make_shared("NMSWithMask"); -const PrimitivePtr kPrimPad = std::make_shared("Pad"); -const PrimitivePtr kPrimArgMaxWithValue = std::make_shared("ArgMaxWithValue"); - -// Maths -const PrimitivePtr kPrimTensorAdd = std::make_shared("TensorAdd"); -const PrimitivePtr kPrimMatMul = std::make_shared("MatMul"); -const PrimitivePtr kPrimBatchMatMul = std::make_shared("BatchMatMul"); -const PrimitivePtr kPrimMaximumGrad = std::make_shared("MaximumGrad"); -const PrimitivePtr kPrimMinimumGrad = std::make_shared("MinimumGrad"); -const PrimitivePtr kPrimReduceMean = std::make_shared("ReduceMean"); -const PrimitivePtr kPrimReduceSum = std::make_shared("ReduceSum"); -const PrimitivePtr kPrimReduceAll = std::make_shared("ReduceAll"); -const PrimitivePtr kPrimReduceMax = std::make_shared("ReduceMax"); -const PrimitivePtr kPrimReduceMin = std::make_shared("ReduceMin"); -const PrimitivePtr kPrimNeg = std::make_shared("Neg"); -const PrimitivePtr kPrimSub = std::make_shared("Sub"); -const PrimitivePtr kPrimMul = std::make_shared("Mul"); -const PrimitivePtr kPrimMinimum = std::make_shared("Minimum"); -const PrimitivePtr kPrimMaximum = std::make_shared("Maximum"); -const PrimitivePtr kPrimSquare = std::make_shared("Square"); -const PrimitivePtr kPrimCumSum = std::make_shared("CumSum"); -const PrimitivePtr kPrimCumProd = std::make_shared("CumProd"); -const PrimitivePtr kPrimSubscalar = std::make_shared("Subscalar"); -const PrimitivePtr kPrimInplaceAdd = std::make_shared("InplaceAdd"); -const PrimitivePtr kPrimInplaceSub = std::make_shared("InplaceSub"); -const PrimitivePtr kPrimPow = std::make_shared("Pow"); -const PrimitivePtr kPrimRealDiv = std::make_shared("RealDiv"); -const PrimitivePtr kPrimSqrt = std::make_shared("Sqrt"); -const PrimitivePtr kPrimReciprocal = std::make_shared("Reciprocal"); -const PrimitivePtr kPrimExpandDims = std::make_shared("ExpandDims"); - -// NN -const PrimitivePtr kPrimFlatten = std::make_shared("Flatten"); -const PrimitivePtr kPrimSoftmax = std::make_shared("Softmax"); -const PrimitivePtr kPrimLogSoftmax = std::make_shared("LogSoftmax"); -const PrimitivePtr kPrimLogSoftmaxGrad = std::make_shared("LogSoftmaxGrad"); -const PrimitivePtr kPrimTanh = std::make_shared("Tanh"); -const PrimitivePtr kPrimTanhGrad = std::make_shared("TanhGrad"); -const PrimitivePtr kPrimPooling = std::make_shared("Pooling"); -const PrimitivePtr kPrimPoolingGrad = std::make_shared("PoolingGrad"); -const PrimitivePtr kPrimMaxPool = std::make_shared("MaxPool"); -const PrimitivePtr kPrimMaxPoolGrad = std::make_shared("MaxPoolGrad"); -const PrimitivePtr kPrimApplyCenteredRMSProp = std::make_shared("ApplyCenteredRMSProp"); -const PrimitivePtr kPrimAvgPoolGrad = std::make_shared("AvgPoolGrad"); -const PrimitivePtr kPrimFusedBatchNorm = std::make_shared("FusedBatchNorm"); -const PrimitivePtr kPrimConv2D = std::make_shared("Conv2D"); -const PrimitivePtr kPrimFusedBatchNormGrad = std::make_shared("FusedBatchNormGrad"); -const PrimitivePtr kPrimBatchNorm = std::make_shared("BatchNorm"); -const PrimitivePtr kPrimBatchNormGrad = std::make_shared("BatchNormGrad"); -const PrimitivePtr kPrimReluGrad = std::make_shared("ReluGrad"); -const PrimitivePtr kPrimConv2DBackpropInput = std::make_shared("Conv2DBackpropInput"); -const PrimitivePtr kPrimConv2DBackpropFilter = std::make_shared("Conv2DBackpropFilter"); -const PrimitivePtr kPrimDepthwiseConv2dNative = std::make_shared("DepthwiseConv2dNative"); -const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter = - std::make_shared("DepthwiseConv2dNativeBackpropFilter"); -const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput = - std::make_shared("DepthwiseConv2dNativeBackpropInput"); -const PrimitivePtr kPrimBiasAddGrad = std::make_shared("BiasAddGrad"); -const PrimitivePtr kPrimSoftmaxCrossEntropyWithLogits = std::make_shared("SoftmaxCrossEntropyWithLogits"); -const PrimitivePtr kPrimSparseSoftmaxCrossEntropyWithLogits = - std::make_shared("SparseSoftmaxCrossEntropyWithLogits"); -const PrimitivePtr kPrimMomentum = std::make_shared("Momentum"); -const PrimitivePtr kPrimApplyMomentum = std::make_shared("ApplyMomentum"); -const PrimitivePtr kPrimLayerNorm = std::make_shared("LayerNorm"); -const PrimitivePtr kPrimLayerNormGrad = std::make_shared("LayerNormGrad"); -const PrimitivePtr kPrimLayerNormXBackprop = std::make_shared("LayerNormXBackprop"); -const PrimitivePtr kPrimLayerNormBetaGammaBackprop = std::make_shared("LayerNormBetaGammaBackprop"); -const PrimitivePtr kPrimDropoutGenMask = std::make_shared("DropoutGenMask"); -const PrimitivePtr kPrimDropoutDoMask = std::make_shared("DropoutDoMask"); -const PrimitivePtr kPrimOneHot = std::make_shared("OneHot"); -const PrimitivePtr kPrimGelu = std::make_shared("Gelu"); -const PrimitivePtr kPrimGeluGrad = std::make_shared("GeluGrad"); -const PrimitivePtr kPrimRelu = std::make_shared("ReLU"); -const PrimitivePtr kPrimReluV2 = std::make_shared("ReLUV2"); -const PrimitivePtr kPrimZerosLike = std::make_shared("ZerosLike"); -const PrimitivePtr kPrimFakeBprop = std::make_shared("fake_bprop"); -const PrimitivePtr kPrimBpropCut = std::make_shared("bprop_cut"); -const PrimitivePtr kPrimFakeQuantPerLayer = std::make_shared("FakeQuantPerLayer"); -const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared("FakeQuantPerChannel"); -const PrimitivePtr kPrimApplyRMSProp = std::make_shared("ApplyRMSProp"); - -// Other miscellaneous -const PrimitivePtr kPrimIdentity = std::make_shared("identity"); -const PrimitivePtr kPrimPartial = std::make_shared("Partial"); -const PrimitivePtr kPrimJ = std::make_shared("J"); -const PrimitivePtr kPrimEnvSetItem = std::make_shared("env_setitem"); -const PrimitivePtr kPrimEnvGetItem = std::make_shared("env_getitem"); -const PrimitivePtr kPrimEnvAdd = std::make_shared("env_add"); -const PrimitivePtr kPrimMakeRefKey = std::make_shared("MakeRefKey"); -const PrimitivePtr kPrimGetRefKey = std::make_shared("get_ref_key"); -const PrimitivePtr kPrimGetRefValue = std::make_shared("get_ref_value"); -const PrimitivePtr kPrimGetRefOrigin = std::make_shared("get_ref_origin"); -const PrimitivePtr kPrimInsertGradientOf = std::make_shared("InsertGradientOf"); -const PrimitivePtr kPrimHookBackward = std::make_shared("HookBackward"); -const PrimitivePtr kPrimPrintShapeType = std::make_shared("PrintShapeType"); -const PrimitivePtr kPrimSameTypeShape = std::make_shared("SameTypeShape"); -const PrimitivePtr kPrimCheckBprop = std::make_shared("CheckBprop"); -const PrimitivePtr kPrimPrint = std::make_shared("Print"); - -const PrimitivePtr kPrimMakeRef = std::make_shared("make_ref"); -const PrimitivePtr kPrimDepend = std::make_shared("Depend"); -const PrimitivePtr kPrimStateSetItem = std::make_shared("state_setitem"); - -const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared("BroadcastGradientArgs"); -const PrimitivePtr kPrimControlDepend = std::make_shared("ControlDepend"); -const PrimitivePtr kPrimIs_ = std::make_shared("is_"); -const PrimitivePtr kPrimIsNot = std::make_shared("is_not"); -const PrimitivePtr kPrimInDict = std::make_shared("in_dict"); -const PrimitivePtr kPrimNotInDict = std::make_shared("not_in_dict"); -const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared("mixed_precision_cast"); -const PrimitivePtr kPrimIsConsant = std::make_shared("is_constant"); -const PrimitivePtr kPrimEquivFormat = std::make_shared("EquivFormat"); - -// Comm ops -const PrimitivePtr kPrimMirror = std::make_shared("_MirrorOperator"); -const PrimitivePtr kPrimVirtualDiv = std::make_shared("_VirtualDiv"); -const PrimitivePtr kPrimVirtualDataset = std::make_shared("_VirtualDataset"); -const PrimitivePtr kPrimAllReduce = std::make_shared("AllReduce"); - -// Debug ops -const PrimitivePtr kPrimScalarSummary = std::make_shared("ScalarSummary"); -const PrimitivePtr kPrimImageSummary = std::make_shared("ImageSummary"); -const PrimitivePtr kPrimTensorSummary = std::make_shared("TensorSummary"); -const PrimitivePtr kPrimHistogramSummary = std::make_shared("HistogramSummary"); -const PrimitivePtr kPrimDebug = std::make_shared("Debug"); - -// IndexedSlices -const PrimitivePtr kPrimMakeIndexedSlices = std::make_shared("MakeIndexedSlices"); -const PrimitivePtr kPrimIndexedSlicesGetValues = std::make_shared("IndexedSlicesGetValues"); -const PrimitivePtr kPrimIndexedSlicesGetIndices = std::make_shared("IndexedSlicesGetIndices"); -const PrimitivePtr kPrimIndexedSlicesGetDenseShape = std::make_shared("IndexedSlicesGetDenseShape"); - -// SparseTensor -const PrimitivePtr kPrimMakeSparseTensor = std::make_shared("MakeSparseTensor"); -const PrimitivePtr kPrimSparseTensorGetValues = std::make_shared("SparseTensorGetValues"); -const PrimitivePtr kPrimSparseTensorGetIndices = std::make_shared("SparseTensorGetIndices"); -const PrimitivePtr kPrimSparseTensorGetDenseShape = std::make_shared("SparseTensorGetDenseShape"); -} // namespace prim -} // namespace mindspore diff --git a/mindspore/ccsrc/frontend/operator/ops.h b/mindspore/ccsrc/frontend/operator/ops.h index 13649ef2d2..5b73daba11 100755 --- a/mindspore/ccsrc/frontend/operator/ops.h +++ b/mindspore/ccsrc/frontend/operator/ops.h @@ -22,6 +22,7 @@ #include #include "ir/anf.h" #include "ir/primitive.h" +#include "base/core_ops.h" namespace mindspore { // namespace to support primitive operators @@ -31,273 +32,158 @@ ValuePtr GetPythonOps(const std::string &op_name, bool use_signature = false); // Arithmetic -extern const PrimitivePtr kPrimScalarAdd; -extern const PrimitivePtr kPrimScalarSub; -extern const PrimitivePtr kPrimScalarMul; -extern const PrimitivePtr kPrimScalarDiv; -extern const PrimitivePtr kPrimScalarFloordiv; -extern const PrimitivePtr kPrimScalarMod; -extern const PrimitivePtr kPrimScalarPow; -extern const PrimitivePtr kPrimScalarTrunc; -extern const PrimitivePtr kPrimScalarFloor; -extern const PrimitivePtr kPrimScalarUadd; -extern const PrimitivePtr kPrimScalarUsub; -extern const PrimitivePtr kPrimScalarExp; -extern const PrimitivePtr kPrimScalarLog; -extern const PrimitivePtr kPrimScalarSin; -extern const PrimitivePtr kPrimScalarCos; -extern const PrimitivePtr kPrimScalarTan; +inline const PrimitivePtr kPrimScalarAdd = std::make_shared("scalar_add"); +inline const PrimitivePtr kPrimScalarSub = std::make_shared("scalar_sub"); +inline const PrimitivePtr kPrimScalarMul = std::make_shared("scalar_mul"); +inline const PrimitivePtr kPrimScalarDiv = std::make_shared("scalar_div"); +inline const PrimitivePtr kPrimScalarFloordiv = std::make_shared("scalar_floordiv"); +inline const PrimitivePtr kPrimScalarMod = std::make_shared("scalar_mod"); +inline const PrimitivePtr kPrimScalarPow = std::make_shared("scalar_pow"); +inline const PrimitivePtr kPrimScalarTrunc = std::make_shared("scalar_trunc"); +inline const PrimitivePtr kPrimScalarFloor = std::make_shared("scalar_floor"); +inline const PrimitivePtr kPrimScalarUadd = std::make_shared("scalar_uadd"); +inline const PrimitivePtr kPrimScalarUsub = std::make_shared("scalar_usub"); +inline const PrimitivePtr kPrimScalarExp = std::make_shared("scalar_exp"); +inline const PrimitivePtr kPrimScalarLog = std::make_shared("scalar_log"); +inline const PrimitivePtr kPrimScalarSin = std::make_shared("scalar_sin"); +inline const PrimitivePtr kPrimScalarCos = std::make_shared("scalar_cos"); +inline const PrimitivePtr kPrimScalarTan = std::make_shared("scalar_tan"); // Comparisons -extern const PrimitivePtr kPrimScalarEq; -extern const PrimitivePtr kPrimScalarLt; -extern const PrimitivePtr kPrimScalarGt; -extern const PrimitivePtr kPrimScalarNe; -extern const PrimitivePtr kPrimScalarLe; -extern const PrimitivePtr kPrimScalarGe; -extern const PrimitivePtr kPrimBoolNot; -extern const PrimitivePtr kPrimBoolAnd; -extern const PrimitivePtr kPrimBoolOr; -extern const PrimitivePtr kPrimBoolEq; -extern const PrimitivePtr kPrimGreater; -extern const PrimitivePtr kPrimGreaterEqual; -extern const PrimitivePtr kPrimLess; -extern const PrimitivePtr kPrimLessEqual; -extern const PrimitivePtr kPrimEqual; -extern const PrimitivePtr kPrimNotEqual; +inline const PrimitivePtr kPrimScalarEq = std::make_shared("scalar_eq"); +inline const PrimitivePtr kPrimScalarLt = std::make_shared("scalar_lt"); +inline const PrimitivePtr kPrimScalarGt = std::make_shared("scalar_gt"); +inline const PrimitivePtr kPrimScalarNe = std::make_shared("scalar_ne"); +inline const PrimitivePtr kPrimScalarLe = std::make_shared("scalar_le"); +inline const PrimitivePtr kPrimScalarGe = std::make_shared("scalar_ge"); +inline const PrimitivePtr kPrimBoolNot = std::make_shared("bool_not"); +inline const PrimitivePtr kPrimBoolAnd = std::make_shared("bool_and"); +inline const PrimitivePtr kPrimBoolOr = std::make_shared("bool_or"); +inline const PrimitivePtr kPrimBoolEq = std::make_shared("bool_eq"); +inline const PrimitivePtr kPrimGreater = std::make_shared("Greater"); +inline const PrimitivePtr kPrimGreaterEqual = std::make_shared("GreaterEqual"); +inline const PrimitivePtr kPrimLess = std::make_shared("Less"); +inline const PrimitivePtr kPrimLessEqual = std::make_shared("LessEqual"); +inline const PrimitivePtr kPrimEqual = std::make_shared("Equal"); +inline const PrimitivePtr kPrimNotEqual = std::make_shared("NotEqual"); // Type introspection -extern const PrimitivePtr kPrimTypeOf; -extern const PrimitivePtr kPrimHasType; +inline const PrimitivePtr kPrimTypeOf = std::make_shared("typeof"); +inline const PrimitivePtr kPrimHasType = std::make_shared("hastype"); -// Statements -extern const PrimitivePtr kPrimSwitch; -extern const PrimitivePtr kPrimSwitchLayer; -extern const PrimitivePtr kPrimReturn; -extern const PrimitivePtr kPrimAssign; -extern const PrimitivePtr kPrimAssignAdd; -extern const PrimitivePtr kPrimAssignSub; -extern const PrimitivePtr kPrimSelect; -extern const PrimitivePtr kPrimCall; +inline const PrimitivePtr kPrimDistribute = std::make_shared("distribute"); +inline const PrimitivePtr kPrimDot = std::make_shared("dot"); +inline const PrimitivePtr kPrimIm2Col = std::make_shared("im2col"); +inline const PrimitivePtr kPrimCol2Im = std::make_shared("col2im"); +inline const PrimitivePtr kPrimIm2ColV1 = std::make_shared("im2col_v1"); +inline const PrimitivePtr kPrimCol2ImV1 = std::make_shared("col2im_v1"); -extern const PrimitivePtr kPrimDistribute; -extern const PrimitivePtr kPrimDot; -extern const PrimitivePtr kPrimIm2Col; -extern const PrimitivePtr kPrimCol2Im; -extern const PrimitivePtr kPrimIm2ColV1; -extern const PrimitivePtr kPrimCol2ImV1; +inline const PrimitivePtr kPrimResolve = std::make_shared("resolve"); +inline const PrimitivePtr kPrimEmbed = std::make_shared("embed"); +inline const PrimitivePtr kPrimRefToEmbed = std::make_shared("RefToEmbed"); +inline const PrimitivePtr kPrimCreateInstance = std::make_shared("create_instance"); -extern const PrimitivePtr kPrimResolve; -extern const PrimitivePtr kPrimEmbed; -extern const PrimitivePtr kPrimRefToEmbed; -extern const PrimitivePtr kPrimCreateInstance; - -extern const PrimitivePtr kPrimLabelGoto; -extern const PrimitivePtr kPrimLabelSwitch; -extern const PrimitivePtr kPrimLabelSet; - -// Structure -extern const PrimitivePtr kPrimStringEqual; -extern const PrimitivePtr kPrimStringConcat; -extern const PrimitivePtr kPrimMakeTuple; -extern const PrimitivePtr kPrimMakeList; -extern const PrimitivePtr kPrimMakeDict; -extern const PrimitivePtr kPrimMakeKeywordArg; -extern const PrimitivePtr kPrimExtractKeywordArg; -extern const PrimitivePtr kPrimMakeSlice; -extern const PrimitivePtr kPrimMakeRecord; -extern const PrimitivePtr kPrimTupleGetItem; -extern const PrimitivePtr kPrimListGetItem; -extern const PrimitivePtr kPrimArrayGetItem; -extern const PrimitivePtr kPrimTupleSetItem; -extern const PrimitivePtr kPrimListSetItem; -extern const PrimitivePtr kPrimArraySetItem; -extern const PrimitivePtr kPrimDictGetItem; -extern const PrimitivePtr kPrimDictSetItem; -extern const PrimitivePtr kPrimListAppend; -extern const PrimitivePtr kPrimGetAttr; -extern const PrimitivePtr kPrimTupleLen; -extern const PrimitivePtr kPrimDictLen; -extern const PrimitivePtr kPrimListLen; -extern const PrimitivePtr kPrimArrayLen; -extern const PrimitivePtr kPrimListMap; -extern const PrimitivePtr kPrimListReduce; -extern const PrimitivePtr kPrimTupleReversed; -extern const PrimitivePtr kPrimTileShape; -extern const PrimitivePtr kPrimReducedShape; -extern const PrimitivePtr kPrimTupleDiv; -extern const PrimitivePtr kPrimTupleToArray; -extern const PrimitivePtr kPrimShapeMul; -extern const PrimitivePtr kPrimGenerateShapeIndex; -extern const PrimitivePtr kPrimGenerateInverseIndex; -extern const PrimitivePtr kPrimTupleEqual; -extern const PrimitivePtr kPrimListEqual; -extern const PrimitivePtr kPrimMakeRange; -extern const PrimitivePtr kPrimStopGradient; +inline const PrimitivePtr kPrimLabelGoto = std::make_shared("LabelGoto"); +inline const PrimitivePtr kPrimLabelSwitch = std::make_shared("LabelSwitch"); +inline const PrimitivePtr kPrimLabelSet = std::make_shared("LabelSet"); // Arrays -extern const PrimitivePtr kPrimScalarToArray; -extern const PrimitivePtr kPrimArrayToScalar; -extern const PrimitivePtr kPrimBroadcastShape; -extern const PrimitivePtr kPrimArrayMap; -extern const PrimitivePtr kPrimArrayReduce; -extern const PrimitivePtr kPrimShape; -extern const PrimitivePtr kPrimCast; -extern const PrimitivePtr kPrimConcat; -extern const PrimitivePtr kPrimSqueeze; -extern const PrimitivePtr kPrimTranspose; -extern const PrimitivePtr kPrimGatherV2; -extern const PrimitivePtr kPrimEmbeddingLookup; -extern const PrimitivePtr kPrimEmbeddingLookupCommGrad; -extern const PrimitivePtr kPrimSize; -extern const PrimitivePtr kPrimArgMax; -extern const PrimitivePtr kPrimPack; -extern const PrimitivePtr kPrimUnpack; -extern const PrimitivePtr kPrimUnsortedSegmentMin; -extern const PrimitivePtr kPrimUnsortedSegmentSum; -extern const PrimitivePtr kPrimConcatOffset; -extern const PrimitivePtr kPrimReshape; -extern const PrimitivePtr kPrimTile; -extern const PrimitivePtr kPrimAddN; -extern const PrimitivePtr KPrimTransData; -extern const PrimitivePtr kPrimNMSWithMask; -extern const PrimitivePtr kPrimPad; -extern const PrimitivePtr kPrimArgMaxWithValue; -extern const PrimitivePtr kPrimRealDiv; -extern const PrimitivePtr kPrimSqrt; -extern const PrimitivePtr kPrimReciprocal; -extern const PrimitivePtr kPrimExpandDims; - -// Maths -extern const PrimitivePtr kPrimTensorAdd; -extern const PrimitivePtr kPrimMatMul; -extern const PrimitivePtr kPrimBatchMatMul; -extern const PrimitivePtr kPrimMaximumGrad; -extern const PrimitivePtr kPrimMinimumGrad; -extern const PrimitivePtr kPrimReduceMean; -extern const PrimitivePtr kPrimReduceSum; -extern const PrimitivePtr kPrimReduceAll; -extern const PrimitivePtr kPrimReduceMax; -extern const PrimitivePtr kPrimReduceMin; -extern const PrimitivePtr kPrimNeg; -extern const PrimitivePtr kPrimSub; -extern const PrimitivePtr kPrimMul; -extern const PrimitivePtr kPrimRealDiv; -extern const PrimitivePtr kPrimMinimum; -extern const PrimitivePtr kPrimMaximum; -extern const PrimitivePtr kPrimSquare; -extern const PrimitivePtr kPrimSqrt; -extern const PrimitivePtr kPrimEqual; -extern const PrimitivePtr kPrimLess; -extern const PrimitivePtr kPrimLessEqual; -extern const PrimitivePtr kPrimCumSum; -extern const PrimitivePtr kPrimCumProd; -extern const PrimitivePtr kPrimSubscalar; -extern const PrimitivePtr kPrimInplaceAdd; -extern const PrimitivePtr kPrimInplaceSub; -extern const PrimitivePtr kPrimPow; +inline const PrimitivePtr kPrimScalarToArray = std::make_shared("scalar_to_array"); +inline const PrimitivePtr kPrimArrayToScalar = std::make_shared("array_to_scalar"); +inline const PrimitivePtr kPrimBroadcastShape = std::make_shared("broadcast_shape"); +inline const PrimitivePtr kPrimArrayMap = std::make_shared("array_map"); +inline const PrimitivePtr kPrimArrayReduce = std::make_shared("array_reduce"); +inline const PrimitivePtr kPrimShape = std::make_shared("Shape"); +inline const PrimitivePtr kPrimCast = std::make_shared("Cast"); +inline const PrimitivePtr kPrimConcat = std::make_shared("Concat"); +inline const PrimitivePtr kPrimSqueeze = std::make_shared("Squeeze"); +inline const PrimitivePtr kPrimTranspose = std::make_shared("Transpose"); +inline const PrimitivePtr kPrimGatherV2 = std::make_shared("GatherV2"); +inline const PrimitivePtr kPrimEmbeddingLookup = std::make_shared("EmbeddingLookup"); +inline const PrimitivePtr kPrimEmbeddingLookupCommGrad = std::make_shared("EmbeddingLookupCommGrad"); +inline const PrimitivePtr kPrimSize = std::make_shared("Size"); +inline const PrimitivePtr kPrimArgMax = std::make_shared("Argmax"); +inline const PrimitivePtr kPrimPack = std::make_shared("Pack"); +inline const PrimitivePtr kPrimUnsortedSegmentSum = std::make_shared("UnsortedSegmentSum"); +inline const PrimitivePtr kPrimUnsortedSegmentMin = std::make_shared("UnsortedSegmentMin"); +inline const PrimitivePtr kPrimConcatOffset = std::make_shared("ConcatOffset"); +inline const PrimitivePtr kPrimReshape = std::make_shared("Reshape"); +inline const PrimitivePtr kPrimTile = std::make_shared("Tile"); +inline const PrimitivePtr kPrimAddN = std::make_shared("AddN"); +inline const PrimitivePtr KPrimTransData = std::make_shared("TransData"); +inline const PrimitivePtr kPrimNMSWithMask = std::make_shared("NMSWithMask"); +inline const PrimitivePtr kPrimPad = std::make_shared("Pad"); +inline const PrimitivePtr kPrimArgMaxWithValue = std::make_shared("ArgMaxWithValue"); // NN -extern const PrimitivePtr kPrimFlatten; -extern const PrimitivePtr kPrimSoftmax; -extern const PrimitivePtr kPrimLogSoftmax; -extern const PrimitivePtr kPrimLogSoftmaxGrad; -extern const PrimitivePtr kPrimApplyCenteredRMSProp; -extern const PrimitivePtr kPrimTanh; -extern const PrimitivePtr kPrimTanhGrad; -extern const PrimitivePtr kPrimPooling; -extern const PrimitivePtr kPrimPoolingGrad; -extern const PrimitivePtr kPrimFusedBatchNorm; -extern const PrimitivePtr kPrimBatchNorm; -extern const PrimitivePtr kPrimBatchNormGrad; -extern const PrimitivePtr kPrimConv2D; -extern const PrimitivePtr kPrimMaxPool; -extern const PrimitivePtr kPrimMaxPoolGrad; -extern const PrimitivePtr kPrimAvgPoolGrad; -extern const PrimitivePtr kPrimFusedBatchNormGrad; -extern const PrimitivePtr kPrimReluGrad; -extern const PrimitivePtr kPrimConv2DBackpropInput; -extern const PrimitivePtr kPrimConv2DBackpropFilter; -extern const PrimitivePtr kPrimDepthwiseConv2dNative; -extern const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter; -extern const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput; - -extern const PrimitivePtr kPrimBiasAddGrad; -extern const PrimitivePtr kPrimSoftmaxCrossEntropyWithLogits; -extern const PrimitivePtr kPrimSparseSoftmaxCrossEntropyWithLogits; -extern const PrimitivePtr kPrimMomentum; -extern const PrimitivePtr kPrimApplyMomentum; -extern const PrimitivePtr kPrimLayerNorm; -extern const PrimitivePtr kPrimLayerNormGrad; -extern const PrimitivePtr kPrimLayerNormXBackprop; -extern const PrimitivePtr kPrimLayerNormBetaGammaBackprop; -extern const PrimitivePtr kPrimDropoutGenMask; -extern const PrimitivePtr kPrimDropoutDoMask; -extern const PrimitivePtr kPrimOneHot; -extern const PrimitivePtr kPrimGelu; -extern const PrimitivePtr kPrimGeluGrad; -extern const PrimitivePtr kPrimRelu; -extern const PrimitivePtr kPrimReluV2; -extern const PrimitivePtr kPrimActivation; -extern const PrimitivePtr kPrimZerosLike; -extern const PrimitivePtr kPrimFakeBprop; -extern const PrimitivePtr kPrimBpropCut; -extern const PrimitivePtr kPrimFakeQuantPerLayer; -extern const PrimitivePtr kPrimFakeQuantPerChannel; -extern const PrimitivePtr kPrimApplyRMSProp; - -// Other Miscellaneous -extern const PrimitivePtr kPrimIdentity; -extern const PrimitivePtr kPrimPartial; -extern const PrimitivePtr kPrimJ; -extern const PrimitivePtr kPrimEnvSetItem; -extern const PrimitivePtr kPrimEnvGetItem; -extern const PrimitivePtr kPrimEnvAdd; -extern const PrimitivePtr kPrimMakeRefKey; -extern const PrimitivePtr kPrimMakeRef; -extern const PrimitivePtr kPrimGetRefKey; -extern const PrimitivePtr kPrimGetRefValue; -extern const PrimitivePtr kPrimGetRefOrigin; -extern const PrimitivePtr kPrimInsertGradientOf; -extern const PrimitivePtr kPrimHookBackward; -extern const PrimitivePtr kPrimPrintShapeType; -extern const PrimitivePtr kPrimPrint; -extern const PrimitivePtr kPrimSameTypeShape; -extern const PrimitivePtr kPrimCheckBprop; -extern const PrimitivePtr kPrimDepend; -extern const PrimitivePtr kPrimStateSetItem; -extern const PrimitivePtr kPrimScalarSummary; -extern const PrimitivePtr kPrimImageSummary; -extern const PrimitivePtr kPrimTensorSummary; -extern const PrimitivePtr kPrimHistogramSummary; -extern const PrimitivePtr kPrimBroadcastGradientArgs; -extern const PrimitivePtr kPrimControlDepend; -extern const PrimitivePtr kPrimIs_; -extern const PrimitivePtr kPrimIsNot; -extern const PrimitivePtr kPrimInDict; -extern const PrimitivePtr kPrimNotInDict; -extern const PrimitivePtr kPrimMixedPrecisionCast; -extern const PrimitivePtr kPrimIsConsant; -extern const PrimitivePtr kPrimEquivFormat; -extern const PrimitivePtr kPrimDebug; +inline const PrimitivePtr kPrimFlatten = std::make_shared("Flatten"); +inline const PrimitivePtr kPrimSoftmax = std::make_shared("Softmax"); +inline const PrimitivePtr kPrimLogSoftmax = std::make_shared("LogSoftmax"); +inline const PrimitivePtr kPrimLogSoftmaxGrad = std::make_shared("LogSoftmaxGrad"); +inline const PrimitivePtr kPrimTanh = std::make_shared("Tanh"); +inline const PrimitivePtr kPrimTanhGrad = std::make_shared("TanhGrad"); +inline const PrimitivePtr kPrimPooling = std::make_shared("Pooling"); +inline const PrimitivePtr kPrimPoolingGrad = std::make_shared("PoolingGrad"); +inline const PrimitivePtr kPrimMaxPool = std::make_shared("MaxPool"); +inline const PrimitivePtr kPrimMaxPoolGrad = std::make_shared("MaxPoolGrad"); +inline const PrimitivePtr kPrimApplyCenteredRMSProp = std::make_shared("ApplyCenteredRMSProp"); +inline const PrimitivePtr kPrimAvgPoolGrad = std::make_shared("AvgPoolGrad"); +inline const PrimitivePtr kPrimFusedBatchNorm = std::make_shared("FusedBatchNorm"); +inline const PrimitivePtr kPrimConv2D = std::make_shared("Conv2D"); +inline const PrimitivePtr kPrimFusedBatchNormGrad = std::make_shared("FusedBatchNormGrad"); +inline const PrimitivePtr kPrimBatchNorm = std::make_shared("BatchNorm"); +inline const PrimitivePtr kPrimBatchNormGrad = std::make_shared("BatchNormGrad"); +inline const PrimitivePtr kPrimReluGrad = std::make_shared("ReluGrad"); +inline const PrimitivePtr kPrimConv2DBackpropInput = std::make_shared("Conv2DBackpropInput"); +inline const PrimitivePtr kPrimConv2DBackpropFilter = std::make_shared("Conv2DBackpropFilter"); +inline const PrimitivePtr kPrimDepthwiseConv2dNative = std::make_shared("DepthwiseConv2dNative"); +inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropFilter = + std::make_shared("DepthwiseConv2dNativeBackpropFilter"); +inline const PrimitivePtr kPrimDepthwiseConv2dNativeBackpropInput = + std::make_shared("DepthwiseConv2dNativeBackpropInput"); +inline const PrimitivePtr kPrimBiasAddGrad = std::make_shared("BiasAddGrad"); +inline const PrimitivePtr kPrimSoftmaxCrossEntropyWithLogits = + std::make_shared("SoftmaxCrossEntropyWithLogits"); +inline const PrimitivePtr kPrimSparseSoftmaxCrossEntropyWithLogits = + std::make_shared("SparseSoftmaxCrossEntropyWithLogits"); +inline const PrimitivePtr kPrimMomentum = std::make_shared("Momentum"); +inline const PrimitivePtr kPrimApplyMomentum = std::make_shared("ApplyMomentum"); +inline const PrimitivePtr kPrimLayerNorm = std::make_shared("LayerNorm"); +inline const PrimitivePtr kPrimLayerNormGrad = std::make_shared("LayerNormGrad"); +inline const PrimitivePtr kPrimLayerNormXBackprop = std::make_shared("LayerNormXBackprop"); +inline const PrimitivePtr kPrimLayerNormBetaGammaBackprop = std::make_shared("LayerNormBetaGammaBackprop"); +inline const PrimitivePtr kPrimDropoutGenMask = std::make_shared("DropoutGenMask"); +inline const PrimitivePtr kPrimDropoutDoMask = std::make_shared("DropoutDoMask"); +inline const PrimitivePtr kPrimOneHot = std::make_shared("OneHot"); +inline const PrimitivePtr kPrimGelu = std::make_shared("Gelu"); +inline const PrimitivePtr kPrimGeluGrad = std::make_shared("GeluGrad"); +inline const PrimitivePtr kPrimRelu = std::make_shared("ReLU"); +inline const PrimitivePtr kPrimReluV2 = std::make_shared("ReLUV2"); +inline const PrimitivePtr kPrimZerosLike = std::make_shared("ZerosLike"); +inline const PrimitivePtr kPrimFakeBprop = std::make_shared("fake_bprop"); +inline const PrimitivePtr kPrimBpropCut = std::make_shared("bprop_cut"); +inline const PrimitivePtr kPrimFakeQuantPerLayer = std::make_shared("FakeQuantPerLayer"); +inline const PrimitivePtr kPrimFakeQuantPerChannel = std::make_shared("FakeQuantPerChannel"); +inline const PrimitivePtr kPrimApplyRMSProp = std::make_shared("ApplyRMSProp"); // Comm ops -extern const PrimitivePtr kPrimAllReduce; -extern const PrimitivePtr kPrimMirror; -extern const PrimitivePtr kPrimVirtualDiv; -extern const PrimitivePtr kPrimVirtualDataset; +inline const PrimitivePtr kPrimMirror = std::make_shared("_MirrorOperator"); +inline const PrimitivePtr kPrimVirtualDiv = std::make_shared("_VirtualDiv"); +inline const PrimitivePtr kPrimVirtualDataset = std::make_shared("_VirtualDataset"); +inline const PrimitivePtr kPrimAllReduce = std::make_shared("AllReduce"); // IndexedSlices -extern const PrimitivePtr kPrimMakeIndexedSlices; -extern const PrimitivePtr kPrimIndexedSlicesGetValues; -extern const PrimitivePtr kPrimIndexedSlicesGetIndices; -extern const PrimitivePtr kPrimIndexedSlicesGetDenseShape; +inline const PrimitivePtr kPrimMakeIndexedSlices = std::make_shared("MakeIndexedSlices"); +inline const PrimitivePtr kPrimIndexedSlicesGetValues = std::make_shared("IndexedSlicesGetValues"); +inline const PrimitivePtr kPrimIndexedSlicesGetIndices = std::make_shared("IndexedSlicesGetIndices"); +inline const PrimitivePtr kPrimIndexedSlicesGetDenseShape = std::make_shared("IndexedSlicesGetDenseShape"); +inline const PrimitivePtr kPrimIsIndexedSlices = std::make_shared("IsIndexedSlices"); // SparseTensor -extern const PrimitivePtr kPrimMakeSparseTensor; -extern const PrimitivePtr kPrimSparseTensorGetValues; -extern const PrimitivePtr kPrimSparseTensorGetIndices; -extern const PrimitivePtr kPrimSparseTensorGetDenseShape; +inline const PrimitivePtr kPrimMakeSparseTensor = std::make_shared("MakeSparseTensor"); +inline const PrimitivePtr kPrimSparseTensorGetValues = std::make_shared("SparseTensorGetValues"); +inline const PrimitivePtr kPrimSparseTensorGetIndices = std::make_shared("SparseTensorGetIndices"); +inline const PrimitivePtr kPrimSparseTensorGetDenseShape = std::make_shared("SparseTensorGetDenseShape"); // attribute 'unroll_flag' of primitive 'switch', when 'unroll_flag' is '0', 'switch' will not unroll const char SWITCH_UNROLL_FLAG[] = "unroll_flag"; @@ -305,22 +191,6 @@ const char SWITCH_UNROLL_FLAG[] = "unroll_flag"; // will be sunk(i.e. not unrolled) const int MAX_FOR_LOOP_COUNT = 600; -class DoSignaturePrimitive : public Primitive { - public: - explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function) - : Primitive("S-Prim-" + name), function_(function) {} - - ~DoSignaturePrimitive() override = default; - - MS_DECLARE_PARENT(DoSignaturePrimitive, Primitive) - - const ValuePtr function() const { return function_; } - - private: - ValuePtr function_; -}; -using DoSignaturePrimitivePtr = std::shared_ptr; - class UnpackGraphPrimitive : public Primitive { public: explicit UnpackGraphPrimitive(const std::string &name, const bool &with_sens, const bool &need_unpack_args) diff --git a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.cc b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.cc index 70ae5a7d20..f69bda4100 100644 --- a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.cc +++ b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_fusion.cc @@ -50,7 +50,7 @@ std::unordered_set FindCNodesWithPara(const AnfNodePtr ¶, uint32_t if (node_prim->name() == DEPEND && node_pair.second != 1) { continue; } - if (IsParallelCareNode(cnode) && cnode->operator_info() != nullptr) { + if (IsParallelCareNode(cnode) && cnode->HasUserData()) { (void)cnode_set.emplace(cnode); } else { auto cnode_set_sub = FindCNodesWithPara(node_pair.first, recursive_times + 1); @@ -98,11 +98,12 @@ CNodeCostMap AllreduceFusion::FindCNode(const AnfNodePtr &from, uint32_t recursi return cnode_dist; } + auto operator_info = cnode->GetUserData(); MS_LOG(DEBUG) << "cnode " << cnode->ToString() << " IsParallelCareNode: " << IsParallelCareNode(cnode) - << " operator_info: " << (cnode->operator_info() != nullptr); + << " operator_info: " << (operator_info != nullptr); - if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) { - auto cost = cnode->operator_info()->GetForwardMemoryCostFromCNode(); + if (IsParallelCareNode(cnode) && (operator_info != nullptr)) { + auto cost = operator_info->GetForwardMemoryCostFromCNode(); MS_LOG(DEBUG) << "cnode " << cnode->DebugString() << " cost: " << cost; if (allreduce_graph_.NodeInGraph(cnode)) { diff --git a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_node.cc b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_node.cc index 1c478887df..93f01943bf 100644 --- a/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_node.cc +++ b/mindspore/ccsrc/frontend/parallel/allreduce_fusion/allreduce_node.cc @@ -83,7 +83,7 @@ Status AllreduceNode::AddPara(const AnfNodePtr &node_ptr) { } auto para_ptr = node_ptr->cast(); MS_EXCEPTION_IF_NULL(para_ptr); - auto layout_ptr = para_ptr->tensor_layout(); + auto layout_ptr = para_ptr->GetUserData(); if (layout_ptr == nullptr) { MS_LOG(ERROR) << "layout_ptr is nullptr!"; return FAILED; diff --git a/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.cc b/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.cc index effbdc17c7..bdcbe9fd3d 100644 --- a/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.cc +++ b/mindspore/ccsrc/frontend/parallel/graph_util/get_parallel_info.cc @@ -37,7 +37,7 @@ py::dict GetParameterLayout(const FuncGraphPtr &graph) { for (auto para : graph_params) { std::string name = std::static_pointer_cast(para)->name(); - std::shared_ptr tensor_layout = std::static_pointer_cast(para)->tensor_layout(); + auto tensor_layout = para->GetUserData(); if (tensor_layout == nullptr) { MS_LOG(INFO) << "GetParameterLayout nullptr name = " << name; } else { @@ -70,7 +70,7 @@ py::dict GetCNodeStrategy(const FuncGraphPtr &graph) { if (node->isa()) { auto cnode = node->cast(); MS_EXCEPTION_IF_NULL(cnode); - auto distributed_operation_info = cnode->operator_info(); + auto distributed_operation_info = cnode->GetUserData(); if (distributed_operation_info != nullptr) { auto strategyPtr = distributed_operation_info->strategy(); if (strategyPtr != nullptr) { diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h index bc1c475b93..49e44ef347 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/operator_info.h @@ -163,6 +163,9 @@ class OperatorInfo { const std::string &type() const { return type_; } const std::unordered_map &attrs() const { return attrs_; } + // Key for user data. + constexpr static char key[] = "OpInfo"; + protected: // needed by rec_parser std::string type_; diff --git a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc index 8d54eb454a..036b8d250c 100644 --- a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc @@ -435,7 +435,7 @@ Status ConstructCostGraphNodesByUniqueId(const std::vector &all_node std::vector inputs_tensor_name = ExtractInputsTensorName(cnode); entire_costgraph->AddOperator(operator_info); - (void)cnode->set_operator_info(operator_info); + cnode->SetUserData(operator_info); MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name(); @@ -501,7 +501,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector &all_no std::vector inputs_tensor_name = ExtractInputsTensorName(cnode); entire_costgraph->AddOperator(operator_info); - (void)cnode->set_operator_info(operator_info); + cnode->SetUserData(operator_info); MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() << " is set OperatorInfo: " << operator_info->name() << ", Primitive: " << prim->name(); @@ -520,7 +520,7 @@ Status ConstructCostGraphNodesByUniqueIdTC(const std::vector &all_no MS_LOG(EXCEPTION) << "The OperatorInfo: " << current_op_ptr->name() << " does not match the Prim: " << prim->name(); } - (void)cnode->set_operator_info(current_op_ptr); + cnode->SetUserData(current_op_ptr); MS_LOG(INFO) << "The CNode with UniqueId: " << cnode->UniqueId() << " and UniqueIdThroughCopy: " << cnode->UniqueIdThroughCopy() << " is set OperatorInfo: " << current_op_ptr->name() << ", Primitive: " << prim->name(); @@ -549,6 +549,8 @@ void ConstructCostGraphEdges(const std::vector &all_nodes) { PrimitivePtr prim = GetValueNode(prim_anf_node); size_t edge_count = 0; + auto node_op_info = cnode->GetUserData(); + for (size_t i = 1; i < inputs.size(); ++i) { auto prev_cnode = inputs[i]->cast(); bool bool_result_prev_cnode = (prev_cnode == nullptr) || (!IsValueNode(prev_cnode->input(0))); @@ -563,8 +565,8 @@ void ConstructCostGraphEdges(const std::vector &all_nodes) { (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND); while (bool_result) { if (IsAutoParallelCareNode(prev_cnode)) { - std::string edge_name = - prev_cnode->operator_info()->name() + OPERATOR_TO_OPERATOR_CONNECTOR + cnode->operator_info()->name(); + auto prev_op_info = prev_cnode->GetUserData(); + std::string edge_name = prev_op_info->name() + OPERATOR_TO_OPERATOR_CONNECTOR + node_op_info->name(); // If the edge between these two operators already has been added, then the edge will not be added again. if (entire_costgraph->IsEdgeInCostGraph(edge_name, output_index, i - 1)) { break; @@ -577,22 +579,20 @@ void ConstructCostGraphEdges(const std::vector &all_nodes) { if (follow_strategy) { // Redistribution in not allowed on the edge. // Elementwise operators have the same strategy as their previous operators. - edge_ptr = std::make_shared(edge_name, prev_cnode->operator_info(), cnode->operator_info(), - output_index, i - 1, false, true); + edge_ptr = std::make_shared(edge_name, prev_op_info, node_op_info, output_index, i - 1, false, true); } else { - edge_ptr = std::make_shared(edge_name, prev_cnode->operator_info(), cnode->operator_info(), - output_index, i - 1, false); + edge_ptr = std::make_shared(edge_name, prev_op_info, node_op_info, output_index, i - 1, false); } // Init costs for this edge if (edge_ptr->InitEdgeCost() != SUCCESS) { MS_LOG(EXCEPTION) << "Edge cost initialization failed"; } - cnode->operator_info()->AddPrevEdge(edge_ptr); - prev_cnode->operator_info()->AddSuccEdge(edge_ptr); - entire_costgraph->AddEdge(prev_cnode->operator_info(), cnode->operator_info(), edge_ptr); - MS_LOG(INFO) << "Successfully adding the edge between " << prev_cnode->operator_info()->name() << " and " - << cnode->operator_info()->name(); + node_op_info->AddPrevEdge(edge_ptr); + prev_op_info->AddSuccEdge(edge_ptr); + entire_costgraph->AddEdge(prev_op_info, node_op_info, edge_ptr); + MS_LOG(INFO) << "Successfully adding the edge between " << prev_op_info->name() << " and " + << node_op_info->name(); edge_count++; break; @@ -633,7 +633,7 @@ void ConstructCostGraphEdges(const std::vector &all_nodes) { (IsAutoParallelCareNode(prev_cnode)) || (prev_prim->name() == TUPLE_GETITEM) || (prev_prim->name() == DEPEND); } } - MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << cnode->operator_info()->name(); + MS_LOG(INFO) << "Successfully created " << edge_count << " edges for: " << node_op_info->name(); } MS_LOG(INFO) << "Constructing edges for cost graph ends."; @@ -750,7 +750,8 @@ void AugmentCostGraph(const std::vector &all_nodes) { for (auto &target : target_set) { auto target_cnode = target.first->cast(); auto input_index = target.second; - (void)target_without_duplicate.insert(std::to_string(input_index) + target_cnode->operator_info()->name()); + (void)target_without_duplicate.insert(std::to_string(input_index) + + target_cnode->GetUserData()->name()); } if (target_without_duplicate.size() <= 1) { continue; @@ -830,24 +831,24 @@ void AugmentCostGraph(const std::vector &all_nodes) { auto target_cnode = target.first->cast(); auto prim = GetValueNode(target_cnode->input(0)); auto input_index = target.second; + auto target_op_info = target_cnode->GetUserData(); - std::string edge_name = - std::string(IDENTITY_INFO) + OPERATOR_TO_OPERATOR_CONNECTOR + target_cnode->operator_info()->name(); + std::string edge_name = std::string(IDENTITY_INFO) + OPERATOR_TO_OPERATOR_CONNECTOR + target_op_info->name(); // If the edge between these two operators already has been added, then the edge will not be added again. if (entire_costgraph->IsEdgeInCostGraph(edge_name, 0, IntToSize(input_index - 1))) { continue; } - std::shared_ptr edge_ptr = std::make_shared( - edge_name, tmp_identity_ptr, target_cnode->operator_info(), 0, input_index - 1, false, true); + std::shared_ptr edge_ptr = + std::make_shared(edge_name, tmp_identity_ptr, target_op_info, 0, input_index - 1, false, true); if (edge_ptr->InitEdgeCost() != SUCCESS) { MS_LOG(EXCEPTION) << "Edge cost initialization failed"; } - target_cnode->operator_info()->AddPrevEdge(edge_ptr); + target_op_info->AddPrevEdge(edge_ptr); tmp_identity_ptr->AddSuccEdge(edge_ptr); - entire_costgraph->AddEdge(tmp_identity_ptr, target_cnode->operator_info(), edge_ptr); + entire_costgraph->AddEdge(tmp_identity_ptr, target_op_info, edge_ptr); MS_LOG(INFO) << "Successfully adding the edge between " << tmp_identity_ptr->name() << " and " - << target_cnode->operator_info()->name(); + << target_op_info->name(); add_identity_edge = true; } if (new_identity && add_identity_edge) { @@ -861,20 +862,13 @@ bool FindReshape(const CNodePtr &cnode) { if ((cnode == nullptr) || !IsValueNode(cnode->input(0))) { return false; } - ValueNodePtr prim_anf_node = cnode->input(0)->cast(); - if (!IsParallelCareNode(cnode) || (cnode->operator_info() == nullptr)) { + if (!IsParallelCareNode(cnode) || !cnode->HasUserData()) { return false; } + ValueNodePtr prim_anf_node = cnode->input(0)->cast(); PrimitivePtr prim = GetValueNode(prim_anf_node); MS_EXCEPTION_IF_NULL(prim); - OperatorInfoPtr operator_info = cnode->operator_info(); - if (operator_info == nullptr) { - MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->ToString() << " OperatorInstance is nullptr"; - } - if (prim->name() != RESHAPE) { - return false; - } - return true; + return (prim->name() == RESHAPE); } // find previous node, then obtain its strategy_cost_ vector to get its layout vector. @@ -890,8 +884,9 @@ bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_ if (!IsValueNode(cnode->input(0))) { return false; } - if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) { - *pre_operator_info = cnode->operator_info(); + auto node_op_info = cnode->GetUserData(); + if (IsParallelCareNode(cnode) && (node_op_info != nullptr)) { + *pre_operator_info = node_op_info; *out_index = 0; return true; } @@ -905,8 +900,9 @@ bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_ MS_LOG(EXCEPTION) << "tuple get item's second input is not a cnode"; } CNodePtr pre_cnode = pre_node->cast(); - if (IsParallelCareNode(pre_cnode) && (pre_cnode->operator_info() != nullptr)) { - *pre_operator_info = pre_cnode->operator_info(); + auto pre_op_info = pre_cnode->GetUserData(); + if (IsParallelCareNode(pre_cnode) && (pre_op_info != nullptr)) { + *pre_operator_info = pre_op_info; return true; } return false; @@ -945,14 +941,15 @@ bool FindNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator if (node_prim->name() == DEPEND && node_pair.second != 1) { continue; } - if (IsParallelCareNode(use_apply) && (use_apply->operator_info() != nullptr)) { + auto op_info = use_apply->GetUserData(); + if (IsParallelCareNode(use_apply) && (op_info != nullptr)) { MS_LOG(INFO) << "FindNextNodeStraCosts success prim " << node_prim->name(); - *next_operator_info = use_apply->operator_info(); + *next_operator_info = op_info; *in_index = node_pair.second - 1; return true; } MS_LOG(DEBUG) << "FindNextNodeStraCosts failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply) - << " " << (use_apply->operator_info() != nullptr); + << " " << (op_info != nullptr); if (FindNextNodeStraCosts(use_apply, next_operator_info, in_index)) { return true; @@ -973,8 +970,8 @@ void ReshapeCostCompute(const std::vector &all_nodes) { int32_t out_index = 0; OperatorInfoPtr pre_operator_info; std::vector> pre_stra_costs; + auto operator_info = cnode->GetUserData(); if (pre_node->isa()) { - OperatorInfoPtr operator_info = cnode->operator_info(); auto reshape_info = std::dynamic_pointer_cast(operator_info); reshape_info->SetCostForReshapeWithParameter(); pre_operator_info = reshape_info; @@ -995,7 +992,6 @@ void ReshapeCostCompute(const std::vector &all_nodes) { } // set input_layout and output_layout for reshape. // init reshape and set cost for each input_layout and output_layout. - OperatorInfoPtr operator_info = cnode->operator_info(); auto reshape_info = std::dynamic_pointer_cast(operator_info); reshape_info->set_pre_operator_name(pre_operator_info->name()); reshape_info->set_pre_operator_index(out_index); diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc index 20eaf329cf..dec37030c7 100644 --- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc @@ -272,7 +272,7 @@ OperatorInfoPtr GetDistributeOperator(const CNodePtr &node) { if (!IsParallelCareNode(node)) { return nullptr; } - OperatorInfoPtr distribute_operator = node->operator_info(); + OperatorInfoPtr distribute_operator = node->GetUserData(); if (distribute_operator == nullptr) { MS_LOG(EXCEPTION) << "GetDistributeOperator:distribute_operator is nullptr"; } @@ -415,7 +415,7 @@ bool IsParallelCareNode(const CNodePtr &cnode) { if (prim->name() == GET_NEXT) { return true; } - if ((prim->name() == CAST) && (cnode->operator_info() == nullptr)) { + if ((prim->name() == CAST) && !cnode->HasUserData()) { return false; } @@ -452,7 +452,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_ if (node_prim->name() == DEPEND && node_pair.second != 1) { continue; } - if (IsParallelCareNode(use_cnode) && (use_cnode->operator_info() != nullptr)) { + if (IsParallelCareNode(use_cnode) && use_cnode->HasUserData()) { Redistribution(node_pair, distribute_operator, insert_node_new, node_pair.second, tensor_redistribution, pre_node); } else { @@ -465,7 +465,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_ void SplitTensor(const AnfNodePtr &node, const CNodePtr &next_node, int index) { MS_EXCEPTION_IF_NULL(node); MS_EXCEPTION_IF_NULL(next_node); - OperatorInfoPtr op_info = next_node->operator_info(); + OperatorInfoPtr op_info = next_node->GetUserData(); MS_EXCEPTION_IF_NULL(op_info); // If the shape of tensor is [] or [1], no need to split it. @@ -590,7 +590,7 @@ void ReplaceOneOp(const Operator &replace_op, const CNodePtr &node) { void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) { // step1:get graph manager distribute_operator - OperatorInfoPtr distribute_operator = node->operator_info(); + OperatorInfoPtr distribute_operator = node->GetUserData(); if (distribute_operator == nullptr) { MS_LOG(EXCEPTION) << "Failure:AddNode error since distribute_operator is nullptr"; } @@ -628,7 +628,7 @@ void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) { (void)prim->SetAttrs(attrs); } if (index == replace_op.size() - 1) { - (void)replace_node->set_operator_info(node->operator_info()); + replace_node->SetUserData(node->GetUserData()); } replace_node->set_in_forward_flag(true); replace_input[0]->set_scope(scope); @@ -708,7 +708,7 @@ LossNodeInfo GetLossNodeInfo(const AnfNodePtr &loss_node) { auto pre_cnode = pre_node->cast(); MS_EXCEPTION_IF_NULL(pre_cnode); auto pre_prim = GetValueNode(pre_cnode->input(0)); - if (pre_prim->name() == CAST && pre_cnode->operator_info() == nullptr) { + if (pre_prim->name() == CAST && !pre_cnode->HasUserData()) { pre_node = pre_cnode->input(1); } @@ -1204,7 +1204,7 @@ std::pair FindParallelCareNode(const AnfNodePtr &node) { if (node_prim->name() == DEPEND && node_pair.second != 1) { continue; } - if (IsParallelCareNode(cnode) && cnode->operator_info() != nullptr) { + if (IsParallelCareNode(cnode) && cnode->HasUserData()) { return node_pair; } else if (FindParallelCareNode(node_pair.first).first != nullptr) { return FindParallelCareNode(node_pair.first); @@ -1254,7 +1254,7 @@ void SetParallelShape(const AnfNodePtr ¶meter, const std::pairToString() << " shape " << parameter->Shape()->ToString(); CNodePtr cnode = res.first->cast(); MS_EXCEPTION_IF_NULL(cnode); - OperatorInfoPtr distribute_operator = cnode->operator_info(); + OperatorInfoPtr distribute_operator = cnode->GetUserData(); if (distribute_operator == nullptr) { MS_LOG(EXCEPTION) << "Failure:node " << cnode->ToString() << " 's OperatorInfoPtr is nullptr"; } @@ -1277,7 +1277,7 @@ void SetParallelShape(const AnfNodePtr ¶meter, const std::paircast(); MS_EXCEPTION_IF_NULL(parameter_ptr); - parameter_ptr->set_tensor_layout(std::make_shared(tensor_layout)); + parameter_ptr->SetUserData(std::make_shared(tensor_layout)); } void CoverSliceShape(const FuncGraphPtr &root) { @@ -1365,7 +1365,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) { if (found_be_cloned_parameter) { // set the shape and tensor layout for cloned parameter - cloned_parameter->set_tensor_layout(cloned_from_parameter->tensor_layout()); + cloned_parameter->SetUserData(cloned_from_parameter->GetUserData()); MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract()); MS_EXCEPTION_IF_NULL(cloned_from_node->abstract()); auto cloned_abstract = cloned_parameter_node->abstract()->Clone(); @@ -1464,7 +1464,7 @@ void ExtractInformation(const std::vector &all_nodes) { (*operator_).set_outputs_dtype(cnode->Type()); (*operator_).set_cnode(cnode); if (prim->name() == RESHAPE) { - (void)cnode->set_operator_info(operator_); + cnode->SetUserData(operator_); continue; } // load strategy checkpoint @@ -1499,7 +1499,7 @@ void ExtractInformation(const std::vector &all_nodes) { if (operator_->Init(strategyPtr) == FAILED) { MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed"; } - (void)cnode->set_operator_info(operator_); + cnode->SetUserData(operator_); } else { MS_LOG(EXCEPTION) << "ERROR:strategy_ptr is nullptr"; } @@ -1542,13 +1542,13 @@ std::shared_ptr FindNextLayout(const CNodePtr &cnode) { if (node_prim->name() == DEPEND && node_pair.second != 1) { continue; } - if (IsParallelCareNode(use_apply) && (use_apply->operator_info() != nullptr)) { + if (IsParallelCareNode(use_apply) && use_apply->HasUserData()) { MS_LOG(INFO) << "FindNextLayout success prim " << node_prim->name(); auto layout = GetInputLayoutFromCNode(node_pair); return std::make_shared(layout); } MS_LOG(DEBUG) << "FindNextLayout failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply) - << " " << (use_apply->operator_info() != nullptr); + << " " << use_apply->HasUserData(); auto layout_ptr = FindNextLayout(use_apply); if (layout_ptr) { @@ -1580,7 +1580,7 @@ std::shared_ptr FindPrevParallelCareNodeLayout(const AnfNodePtr &n if (!IsValueNode(cnode->input(0))) { return nullptr; } - if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) { + if (IsParallelCareNode(cnode) && cnode->HasUserData()) { auto layout_ptr = GetOutputLayoutFromCNode(cnode, output_index); if (!layout_ptr) { MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed"; @@ -1624,7 +1624,7 @@ std::shared_ptr FindPrevLayout(const AnfNodePtr &node) { if (!IsValueNode(cnode->input(0))) { return nullptr; } - if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) { + if (IsParallelCareNode(cnode) && cnode->HasUserData()) { auto layout_ptr = GetOutputLayoutFromCNode(cnode, 0); if (!layout_ptr) { MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed"; @@ -1664,12 +1664,12 @@ void ReshapeInit(const std::vector &all_nodes) { continue; } ValueNodePtr prim_anf_node = cnode->input(0)->cast(); - if (!IsParallelCareNode(cnode) || (cnode->operator_info() == nullptr)) { + if (!IsParallelCareNode(cnode) || !cnode->HasUserData()) { continue; } PrimitivePtr prim = GetValueNode(prim_anf_node); MS_EXCEPTION_IF_NULL(prim); - OperatorInfoPtr operator_info = cnode->operator_info(); + OperatorInfoPtr operator_info = cnode->GetUserData(); if (operator_info == nullptr) { MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->ToString() << " OperatorInstance is nullptr"; } @@ -1714,7 +1714,7 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) { auto current_prim = GetValueNode(pre_cnode->input(0)); // return -> cast - if (current_prim->name() == CAST && pre_cnode->operator_info() == nullptr) { + if (current_prim->name() == CAST && !pre_cnode->HasUserData()) { pre_cnode = pre_cnode->input(1)->cast(); MS_EXCEPTION_IF_NULL(pre_cnode); current_prim = GetValueNode(pre_cnode->input(0)); @@ -1771,7 +1771,7 @@ TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &loss_cnode) { return ret; } - OperatorInfoPtr operator_info = loss_cnode->operator_info(); + OperatorInfoPtr operator_info = loss_cnode->GetUserData(); MS_EXCEPTION_IF_NULL(operator_info); TensorInfo loss_grad_tensor_info; size_t op_output_size = operator_info->outputs_tensor_info().size(); @@ -1809,7 +1809,7 @@ void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_lay if (sens_tensor_node->isa()) { auto sens_tensor_param = sens_tensor_node->cast(); MS_LOG(DEBUG) << "loss layout " << loss_grad_layout.ToString(); - sens_tensor_param->set_tensor_layout(std::make_shared(loss_grad_layout)); + sens_tensor_param->SetUserData(std::make_shared(loss_grad_layout)); } MS_LOG(INFO) << "The shape of sens is " << ShapeToString(sens_shape) << ", no need to split sens"; return; @@ -1834,7 +1834,7 @@ void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_lay cloned_abstract->set_shape(parallel_shape); sens_tensor_node->set_abstract(cloned_abstract); auto sens_tensor_param = sens_tensor_node->cast(); - sens_tensor_param->set_tensor_layout(std::make_shared(loss_grad_layout)); + sens_tensor_param->SetUserData(std::make_shared(loss_grad_layout)); return; } MS_LOG(EXCEPTION) << "The type of sens node is not Tensor or Parameter, it is unsupported now."; @@ -2125,7 +2125,7 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) { } PrimitivePtr prim = GetValueNode(cnode->input(0)); MS_EXCEPTION_IF_NULL(prim); - OperatorInfoPtr operator_info = cnode->operator_info(); + OperatorInfoPtr operator_info = cnode->GetUserData(); if (operator_info) { if (operator_info->name().find(RESHAPEINFO) != std::string::npos) { continue; diff --git a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h index 0a0a0db575..73b05ea755 100644 --- a/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h +++ b/mindspore/ccsrc/frontend/parallel/tensor_layout/tensor_layout.h @@ -83,6 +83,9 @@ class TensorLayout { TensorLayout SqueezeShape() const; + // Key for user data. + constexpr static char key[] = "TLayout"; + private: std::shared_ptr ExpandTensorShapeWithoutExtendDeviceArrangement( const Arrangement &expanded_shape) const; diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h new file mode 100755 index 0000000000..e2ef23ca68 --- /dev/null +++ b/mindspore/core/base/core_ops.h @@ -0,0 +1,160 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_OPERATOR_OPS_H_ +#define MINDSPORE_CORE_OPERATOR_OPS_H_ + +#include +#include +#include +#include "ir/anf.h" +#include "ir/primitive.h" + +namespace mindspore { +namespace prim { +// Maths +inline const PrimitivePtr kPrimTensorAdd = std::make_shared("TensorAdd"); +inline const PrimitivePtr kPrimMatMul = std::make_shared("MatMul"); +inline const PrimitivePtr kPrimBatchMatMul = std::make_shared("BatchMatMul"); +inline const PrimitivePtr kPrimMaximumGrad = std::make_shared("MaximumGrad"); +inline const PrimitivePtr kPrimMinimumGrad = std::make_shared("MinimumGrad"); +inline const PrimitivePtr kPrimReduceMean = std::make_shared("ReduceMean"); +inline const PrimitivePtr kPrimReduceSum = std::make_shared("ReduceSum"); +inline const PrimitivePtr kPrimReduceAll = std::make_shared("ReduceAll"); +inline const PrimitivePtr kPrimReduceMax = std::make_shared("ReduceMax"); +inline const PrimitivePtr kPrimReduceMin = std::make_shared("ReduceMin"); +inline const PrimitivePtr kPrimNeg = std::make_shared("Neg"); +inline const PrimitivePtr kPrimSub = std::make_shared("Sub"); +inline const PrimitivePtr kPrimMul = std::make_shared("Mul"); +inline const PrimitivePtr kPrimMinimum = std::make_shared("Minimum"); +inline const PrimitivePtr kPrimMaximum = std::make_shared("Maximum"); +inline const PrimitivePtr kPrimSquare = std::make_shared("Square"); +inline const PrimitivePtr kPrimCumSum = std::make_shared("CumSum"); +inline const PrimitivePtr kPrimCumProd = std::make_shared("CumProd"); +inline const PrimitivePtr kPrimSubscalar = std::make_shared("Subscalar"); +inline const PrimitivePtr kPrimInplaceAdd = std::make_shared("InplaceAdd"); +inline const PrimitivePtr kPrimInplaceSub = std::make_shared("InplaceSub"); +inline const PrimitivePtr kPrimPow = std::make_shared("Pow"); +inline const PrimitivePtr kPrimRealDiv = std::make_shared("RealDiv"); +inline const PrimitivePtr kPrimSqrt = std::make_shared("Sqrt"); +inline const PrimitivePtr kPrimReciprocal = std::make_shared("Reciprocal"); +inline const PrimitivePtr kPrimExpandDims = std::make_shared("ExpandDims"); + +// Statements +inline const PrimitivePtr kPrimReturn = std::make_shared("return"); +inline const PrimitivePtr kPrimSwitch = std::make_shared("switch"); +inline const PrimitivePtr kPrimSwitchLayer = std::make_shared("switch_layer"); +inline const PrimitivePtr kPrimAssign = std::make_shared("Assign"); +inline const PrimitivePtr kPrimAssignAdd = std::make_shared("AssignAdd"); +inline const PrimitivePtr kPrimAssignSub = std::make_shared("AssignSub"); +inline const PrimitivePtr kPrimSelect = std::make_shared("Select"); +inline const PrimitivePtr kPrimCall = std::make_shared("call"); + +// Structures +inline const PrimitivePtr kPrimStringEqual = std::make_shared("string_equal"); +inline const PrimitivePtr kPrimStringConcat = std::make_shared("string_concat"); +inline const PrimitivePtr kPrimMakeTuple = std::make_shared("make_tuple"); +inline const PrimitivePtr kPrimMakeDict = std::make_shared("make_dict"); +inline const PrimitivePtr kPrimMakeList = std::make_shared("make_list"); +inline const PrimitivePtr kPrimMakeKeywordArg = std::make_shared("make_keyword_arg"); +inline const PrimitivePtr kPrimMakeSlice = std::make_shared("make_slice"); +inline const PrimitivePtr kPrimMakeRecord = std::make_shared("make_record"); +inline const PrimitivePtr kPrimTupleGetItem = std::make_shared("tuple_getitem"); +inline const PrimitivePtr kPrimListGetItem = std::make_shared("list_getitem"); +inline const PrimitivePtr kPrimArrayGetItem = std::make_shared("array_getitem"); +inline const PrimitivePtr kPrimTupleSetItem = std::make_shared("tuple_setitem"); +inline const PrimitivePtr kPrimListSetItem = std::make_shared("list_setitem"); +inline const PrimitivePtr kPrimArraySetItem = std::make_shared("array_setitem"); +inline const PrimitivePtr kPrimDictGetItem = std::make_shared("dict_getitem"); +inline const PrimitivePtr kPrimDictSetItem = std::make_shared("dict_setitem"); +inline const PrimitivePtr kPrimListAppend = std::make_shared("list_append"); +inline const PrimitivePtr kPrimGetAttr = std::make_shared("getattr"); +inline const PrimitivePtr kPrimTupleLen = std::make_shared("tuple_len"); +inline const PrimitivePtr kPrimDictLen = std::make_shared("dict_len"); +inline const PrimitivePtr kPrimListLen = std::make_shared("list_len"); +inline const PrimitivePtr kPrimArrayLen = std::make_shared("array_len"); +inline const PrimitivePtr kPrimListMap = std::make_shared("list_map"); +inline const PrimitivePtr kPrimListReduce = std::make_shared("list_reduce"); +inline const PrimitivePtr kPrimTupleReversed = std::make_shared("tuple_reversed"); +inline const PrimitivePtr kPrimTileShape = std::make_shared("tile_shape"); +inline const PrimitivePtr kPrimReducedShape = std::make_shared("reduced_shape"); +inline const PrimitivePtr kPrimTupleDiv = std::make_shared("tuple_div"); +inline const PrimitivePtr kPrimTupleToArray = std::make_shared("tuple_to_array"); +inline const PrimitivePtr kPrimShapeMul = std::make_shared("shape_mul"); +inline const PrimitivePtr kPrimGenerateShapeIndex = std::make_shared("generate_shape_index"); +inline const PrimitivePtr kPrimGenerateInverseIndex = std::make_shared("generate_inverse_index"); +inline const PrimitivePtr kPrimTupleEqual = std::make_shared("tuple_equal"); +inline const PrimitivePtr kPrimListEqual = std::make_shared("list_equal"); +inline const PrimitivePtr kPrimMakeRange = std::make_shared("make_range"); +inline const PrimitivePtr kPrimStopGradient = std::make_shared("stop_gradient"); +inline const PrimitivePtr kPrimExtractKeywordArg = std::make_shared("extract_keyword_arg"); + +// Debug ops +inline const PrimitivePtr kPrimScalarSummary = std::make_shared("ScalarSummary"); +inline const PrimitivePtr kPrimImageSummary = std::make_shared("ImageSummary"); +inline const PrimitivePtr kPrimTensorSummary = std::make_shared("TensorSummary"); +inline const PrimitivePtr kPrimHistogramSummary = std::make_shared("HistogramSummary"); +inline const PrimitivePtr kPrimDebug = std::make_shared("Debug"); + +// Other miscellaneous +inline const PrimitivePtr kPrimJ = std::make_shared("J"); +inline const PrimitivePtr kPrimDepend = std::make_shared("Depend"); +inline const PrimitivePtr kPrimPartial = std::make_shared("Partial"); +inline const PrimitivePtr kPrimIdentity = std::make_shared("identity"); +inline const PrimitivePtr kPrimEnvSetItem = std::make_shared("env_setitem"); +inline const PrimitivePtr kPrimEnvGetItem = std::make_shared("env_getitem"); +inline const PrimitivePtr kPrimEnvAdd = std::make_shared("env_add"); +inline const PrimitivePtr kPrimMakeRefKey = std::make_shared("MakeRefKey"); +inline const PrimitivePtr kPrimGetRefKey = std::make_shared("get_ref_key"); +inline const PrimitivePtr kPrimGetRefValue = std::make_shared("get_ref_value"); +inline const PrimitivePtr kPrimGetRefOrigin = std::make_shared("get_ref_origin"); +inline const PrimitivePtr kPrimInsertGradientOf = std::make_shared("InsertGradientOf"); +inline const PrimitivePtr kPrimHookBackward = std::make_shared("HookBackward"); +inline const PrimitivePtr kPrimPrintShapeType = std::make_shared("PrintShapeType"); +inline const PrimitivePtr kPrimSameTypeShape = std::make_shared("SameTypeShape"); +inline const PrimitivePtr kPrimCheckBprop = std::make_shared("CheckBprop"); +inline const PrimitivePtr kPrimPrint = std::make_shared("Print"); +inline const PrimitivePtr kPrimMakeRef = std::make_shared("make_ref"); +inline const PrimitivePtr kPrimStateSetItem = std::make_shared("state_setitem"); +inline const PrimitivePtr kPrimBroadcastGradientArgs = std::make_shared("BroadcastGradientArgs"); +inline const PrimitivePtr kPrimControlDepend = std::make_shared("ControlDepend"); +inline const PrimitivePtr kPrimIs_ = std::make_shared("is_"); +inline const PrimitivePtr kPrimIsNot = std::make_shared("is_not"); +inline const PrimitivePtr kPrimInDict = std::make_shared("in_dict"); +inline const PrimitivePtr kPrimNotInDict = std::make_shared("not_in_dict"); +inline const PrimitivePtr kPrimMixedPrecisionCast = std::make_shared("mixed_precision_cast"); +inline const PrimitivePtr kPrimIsConsant = std::make_shared("is_constant"); +inline const PrimitivePtr kPrimEquivFormat = std::make_shared("EquivFormat"); + +class DoSignaturePrimitive : public Primitive { + public: + explicit DoSignaturePrimitive(const std::string &name, const ValuePtr &function) + : Primitive("S-Prim-" + name), function_(function) {} + + ~DoSignaturePrimitive() override = default; + + MS_DECLARE_PARENT(DoSignaturePrimitive, Primitive) + + const ValuePtr function() const { return function_; } + + private: + ValuePtr function_; +}; +using DoSignaturePrimitivePtr = std::shared_ptr; +} // namespace prim +} // namespace mindspore + +#endif // MINDSPORE_CORE_OPERATOR_OPS_H_ diff --git a/mindspore/core/base/user_data.h b/mindspore/core/base/user_data.h new file mode 100644 index 0000000000..6912d0767d --- /dev/null +++ b/mindspore/core/base/user_data.h @@ -0,0 +1,52 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef MINDSPORE_CORE_USER_DATA_H_ +#define MINDSPORE_CORE_USER_DATA_H_ + +#include +#include +#include + +namespace mindspore { +class UserData { + public: + template + void set(const std::string &key, const std::shared_ptr &value) { + if (value == nullptr) { + data_.erase(key); + } else { + data_.insert_or_assign(key, value); + } + } + + template + std::shared_ptr get(const std::string &key) const { + auto iter = data_.find(key); + if (iter == data_.end()) { + return nullptr; + } + return std::static_pointer_cast(iter->second); + } + + bool has(const std::string &key) const { return data_.find(key) != data_.end(); } + + private: + std::map> data_; +}; +} // namespace mindspore + +#endif // MINDSPORE_CORE_USER_DATA_H_ diff --git a/mindspore/core/ir/anf.cc b/mindspore/core/ir/anf.cc index 275bd3b206..e238012b14 100644 --- a/mindspore/core/ir/anf.cc +++ b/mindspore/core/ir/anf.cc @@ -26,7 +26,6 @@ #include "ir/func_graph.h" #include "ir/primitive.h" #include "utils/context/ms_context.h" -#include "frontend/operator/ops.h" namespace mindspore { // namespace to support intermediate representation definition diff --git a/mindspore/core/ir/anf.h b/mindspore/core/ir/anf.h index a6b2f6bd12..90ef6228a8 100644 --- a/mindspore/core/ir/anf.h +++ b/mindspore/core/ir/anf.h @@ -27,6 +27,7 @@ #include #include "base/base.h" +#include "base/user_data.h" #include "ir/kernel_info_dev.h" #include "ir/scope.h" #include "debug/info.h" @@ -41,12 +42,6 @@ // ANode: Atomic Node // CNode: Complex Node namespace mindspore { -namespace parallel { -class TensorLayout; -class OperatorInfo; -} // namespace parallel -using OperatorInfoPtr = std::shared_ptr; - namespace abstract { class BaseShape; class AbstractBase; @@ -157,6 +152,31 @@ class AnfNode : public Base { } size_t seen_{0}; + template + void SetUserData(const std::string &key, const std::shared_ptr &value) { + user_data_.set(key, value); + } + + template + void SetUserData(const std::shared_ptr &value) { + user_data_.set(T::key, value); + } + + template + std::shared_ptr GetUserData(const std::string &key) const { + return user_data_.get(key); + } + + template + std::shared_ptr GetUserData() const { + return user_data_.get(T::key); + } + + bool HasUserData(const std::string &key) const { return user_data_.has(key); } + + template + bool HasUserData() const { return user_data_.has(T::key); } + protected: // Hold a weak ref to Graph as Graph also hold ref to AnfNode. // Otherwise, func_graph_ and AnfNode will make a reference cycle. @@ -170,6 +190,7 @@ class AnfNode : public Base { std::hash hash_; ScopePtr scope_; KernelInfoDevicePtr kernel_info_; + UserData user_data_; }; // CNode represents the complex node with a set of arguments. @@ -212,9 +233,6 @@ class CNode : public AnfNode { std::string DebugString(int recursive_level = 1) const override; std::string DebugString(bool recursive) const override { return DebugString(recursive ? 1 : 0); } - OperatorInfoPtr set_operator_info(const OperatorInfoPtr &operator_info); - OperatorInfoPtr operator_info() { return operator_info_; } - void set_in_forward_flag(bool flag) { in_forward_flag_ = flag; } bool in_forward_flag() const { return in_forward_flag_; } @@ -224,7 +242,6 @@ class CNode : public AnfNode { std::vector inputs_; VarPtr func_graph_as_var_; bool stop_gradient_; - OperatorInfoPtr operator_info_ = nullptr; bool in_forward_flag_ = false; }; @@ -244,7 +261,7 @@ class ANode : public AnfNode { class Parameter : public ANode { public: explicit Parameter(const FuncGraphPtr &func_graph) - : ANode(func_graph), name_(""), has_default_(false), default_param_(nullptr), tensor_layout_(nullptr) {} + : ANode(func_graph), name_(""), has_default_(false), default_param_(nullptr) {} ~Parameter() override = default; MS_DECLARE_PARENT(Parameter, ANode); @@ -261,11 +278,6 @@ class Parameter : public ANode { } ParamValuePtr default_param() const { return default_param_; } - std::shared_ptr tensor_layout() const { return tensor_layout_; } - void set_tensor_layout(const std::shared_ptr &tensor_layout) { - tensor_layout_ = tensor_layout; - } - bool operator==(const AnfNode &other) const override { if (!other.isa()) { return false; @@ -281,7 +293,6 @@ class Parameter : public ANode { std::string name_; bool has_default_; ParamValuePtr default_param_; - std::shared_ptr tensor_layout_; }; using ParameterPtr = std::shared_ptr; diff --git a/mindspore/core/ir/anf_extends.cc b/mindspore/core/ir/anf_extends.cc index b70a660aae..4fb4d6598c 100644 --- a/mindspore/core/ir/anf_extends.cc +++ b/mindspore/core/ir/anf_extends.cc @@ -23,8 +23,7 @@ #include "ir/visitor.h" #include "ir/func_graph.h" -#include "frontend/operator/ops.h" -#include "frontend/parallel/ops_info/ops_utils.h" +#include "base/core_ops.h" #include "debug/label.h" namespace mindspore { @@ -37,18 +36,6 @@ std::string AnfNode::ToString() const { return mindspore::label_manage::Label(const_cast(this)->shared_from_base()->debug_info()); } -OperatorInfoPtr CNode::set_operator_info(const OperatorInfoPtr &operator_info) { - if (operator_info_ != nullptr) { - MS_LOG(WARNING) << "The CNode: " << ToString() << " has already been set OperatorInfo: " << operator_info_->name() - << ", using the new one: " << operator_info->name(); - auto old_ptr = operator_info_; - operator_info_ = operator_info; - return old_ptr; - } - operator_info_ = operator_info; - return nullptr; -} - std::string CNode::fullname_with_scope() { // if full name is set, return its name immediately if (!fullname_with_scope_.empty()) { diff --git a/mindspore/core/ir/func_graph.cc b/mindspore/core/ir/func_graph.cc index 570ed61f96..1ef9d9c6bd 100644 --- a/mindspore/core/ir/func_graph.cc +++ b/mindspore/core/ir/func_graph.cc @@ -24,7 +24,6 @@ #include "debug/trace.h" #include "ir/manager.h" -#include "frontend/operator/ops.h" #include "utils/ordered_set.h" #include "utils/convert_utils_base.h" diff --git a/mindspore/core/ir/func_graph_cloner.cc b/mindspore/core/ir/func_graph_cloner.cc index 432a924b1e..b76d4868ea 100644 --- a/mindspore/core/ir/func_graph_cloner.cc +++ b/mindspore/core/ir/func_graph_cloner.cc @@ -20,7 +20,7 @@ #include "ir/manager.h" #include "ir/param_value.h" -#include "frontend/operator/ops.h" +#include "base/core_ops.h" #include "utils/convert_utils_base.h" #include "utils/log_adapter.h" #include "utils/profile.h" diff --git a/mindspore/core/ir/func_graph_extends.cc b/mindspore/core/ir/func_graph_extends.cc index 579409b05e..4fa751a32f 100644 --- a/mindspore/core/ir/func_graph_extends.cc +++ b/mindspore/core/ir/func_graph_extends.cc @@ -22,7 +22,7 @@ #include "ir/manager.h" #include "ir/func_graph_cloner.h" -#include "frontend/operator/ops.h" +#include "base/core_ops.h" #include "utils/ordered_set.h" #include "abstract/abstract_value.h" #include "debug/anf_ir_dump.h" diff --git a/mindspore/core/ir/manager.cc b/mindspore/core/ir/manager.cc index 89ca830092..1b14f9184e 100644 --- a/mindspore/core/ir/manager.cc +++ b/mindspore/core/ir/manager.cc @@ -26,7 +26,7 @@ #include "ir/func_graph.h" #include "utils/profile.h" #include "utils/convert_utils_base.h" -#include "frontend/operator/ops.h" +#include "base/core_ops.h" namespace mindspore { diff --git a/mindspore/core/ir/meta_func_graph.cc b/mindspore/core/ir/meta_func_graph.cc index 7953931e8f..44754798d5 100644 --- a/mindspore/core/ir/meta_func_graph.cc +++ b/mindspore/core/ir/meta_func_graph.cc @@ -17,10 +17,8 @@ */ #include "ir/meta_func_graph.h" -#include "pipeline/jit/static_analysis/static_analysis.h" -#include "pipeline/jit/static_analysis/abstract_function.h" +#include "base/core_ops.h" #include "utils/context/ms_context.h" -#include "frontend/operator/ops.h" // namespace to support intermediate representation definition namespace mindspore { diff --git a/mindspore/core/ir/pattern_matcher.h b/mindspore/core/ir/pattern_matcher.h index 3efc7b98b8..6acfb130c5 100644 --- a/mindspore/core/ir/pattern_matcher.h +++ b/mindspore/core/ir/pattern_matcher.h @@ -22,9 +22,9 @@ #include #include -#include "frontend/operator/ops.h" -#include "frontend/optimizer/optimizer.h" #include "ir/anf.h" +#include "ir/optimizer_caller.h" +#include "base/core_ops.h" namespace mindspore { /// diff --git a/mindspore/core/ir/primitive.h b/mindspore/core/ir/primitive.h index 0dd3a656ee..d950c4642e 100644 --- a/mindspore/core/ir/primitive.h +++ b/mindspore/core/ir/primitive.h @@ -25,7 +25,6 @@ #include "ir/dtype/type.h" #include "abstract/abstract_value.h" -#include "frontend/parallel/ops_info/operator_info.h" #include "utils/base_ref_extends.h" namespace mindspore { diff --git a/mindspore/core/ir/primitive_py.cc b/mindspore/core/ir/primitive_py.cc index a726407d25..f8881f26cd 100644 --- a/mindspore/core/ir/primitive_py.cc +++ b/mindspore/core/ir/primitive_py.cc @@ -18,7 +18,6 @@ #include #include #include "ir/signature.h" -#include "frontend/operator/ops.h" #include "./common.h" #include "pipeline/jit/parse/python_adapter.h" #include "pipeline/jit/parse/data_converter.h" diff --git a/mindspore/core/ir/tensor.cc b/mindspore/core/ir/tensor.cc index c04c2cca96..2fadfa84a9 100644 --- a/mindspore/core/ir/tensor.cc +++ b/mindspore/core/ir/tensor.cc @@ -28,7 +28,6 @@ #include #include -#include "runtime/device/device_address.h" #include "abstract/abstract_value.h" namespace mindspore { diff --git a/tests/ut/cpp/parallel/step_auto_parallel_test.cc b/tests/ut/cpp/parallel/step_auto_parallel_test.cc index 6cf7ec66c6..cca7efd62f 100644 --- a/tests/ut/cpp/parallel/step_auto_parallel_test.cc +++ b/tests/ut/cpp/parallel/step_auto_parallel_test.cc @@ -153,7 +153,7 @@ TEST_F(TestStepAutoParallel, test_create_op_instance) { StrategyPtr strategyPtr; std::shared_ptr matmul_info = NewOperatorInstance(prim, attrs, shape); - node->set_operator_info(matmul_info); + node->SetUserData(matmul_info); std::string name_expect = "MatMulInfo00"; std::string name_test = matmul_info->name(); ASSERT_EQ(name_expect, name_test); diff --git a/tests/ut/cpp/parallel/step_parallel_test.cc b/tests/ut/cpp/parallel/step_parallel_test.cc index 383a061805..dbb541932e 100644 --- a/tests/ut/cpp/parallel/step_parallel_test.cc +++ b/tests/ut/cpp/parallel/step_parallel_test.cc @@ -525,8 +525,8 @@ TEST_F(TestStepParallel, GetTensorInLayout) { std::vector shape = {inputs_shape, outputs_shape}; OperatorInfoPtr matmul_info = OperatorInstance(prim, attrs, shape); matmul_info->Init(strategyPtr); - node->set_operator_info(matmul_info); - OperatorInfoPtr distribute_operator_pre = node->operator_info(); + node->SetUserData(matmul_info); + OperatorInfoPtr distribute_operator_pre = node->GetUserData(); TensorLayout tensorlayout_e; std::vector array = {64, 64}; TensorLayout tensorlayout = GetTensorInLayout(node1, prim, distribute_operator_pre);