diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h index 9a85f366b5..5b867c4403 100644 --- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h +++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h @@ -230,6 +230,47 @@ constexpr char ASSIGN[] = "Assign"; constexpr char GET_NEXT[] = "GetNext"; constexpr char SQUEEZE[] = "Squeeze"; constexpr char NEG[] = "Neg"; +constexpr char ABS[] = "Abs"; +constexpr char ACOSH[] = "Acosh"; +constexpr char ASIN[] = "Asin"; +constexpr char ASINH[] = "Asinh"; +constexpr char ATAN[] = "Atan"; +constexpr char ATANH[] = "Atanh"; +constexpr char CEIL[] = "Ceil"; +constexpr char COSH[] = "Cosh"; +constexpr char EXPM1[] = "Expm1"; +constexpr char LOG1P[] = "Log1p"; +constexpr char SIN[] = "Sin"; +constexpr char SINH[] = "Sinh"; +constexpr char TAN[] = "Tan"; +constexpr char RSQRT[] = "Rsqrt"; +constexpr char INV[] = "Inv"; +constexpr char RECIPROCAL[] = "Reciprocal"; +constexpr char ROUND[] = "Round"; +constexpr char FLOOR[] = "Floor"; +constexpr char SIGN[] = "Sign"; +constexpr char ERF[] = "Erf"; +constexpr char ERFC[] = "Erfc"; +constexpr char ZEROSLIKE[] = "ZerosLike"; +constexpr char ONESLIKE[] = "OnesLike"; +constexpr char BESSELI0E[] = "BesselI0e"; +constexpr char BESSELI1E[] = "BesselI1e"; +constexpr char FLOORMOD[] = "FloorMod"; +constexpr char ASSIGN_ADD[] = "AssignAdd"; +constexpr char ATAN2[] = "Atan2"; +constexpr char DIVNONAN[] = "DivNoNan"; +constexpr char LOGICALAND[] = "LogicalAnd"; +constexpr char LOGICALOR[] = "LogicalOr"; +constexpr char ELU[] = "Elu"; +constexpr char RELU6[] = "ReLU6"; +constexpr char RELUV2[] = "ReLUV2"; +constexpr char SOFTPLUS[] = "Softplus"; +constexpr char SOFTSIGN[] = "Softsign"; +constexpr char GREATEREQUAL[] = "GreaterEqual"; +constexpr char LESSEQUAL[] = "LessEqual"; +constexpr char LESS[] = "Less"; +constexpr char APPROXIMATEEQUAL[] = "ApproximateEqual"; +constexpr char MOD[] = "Mod"; constexpr char BATCH_MATMUL[] = "BatchMatMul"; constexpr char EXPAND_DIMS[] = "ExpandDims"; constexpr char SQUARE[] = "Square"; @@ -297,7 +338,6 @@ constexpr char COL2IMV1[] = "col2im_v1"; constexpr char RESOLVE[] = "resolve"; constexpr char EMBED[] = "embed"; constexpr char CREATINSTANCE[] = "create_instance"; -constexpr char ZEROSLIKE[] = "ZerosLike"; constexpr char REF_TO_EMBED[] = "RefToEmbed"; constexpr char STOP_GRADIENT[] = "stop_gradient"; diff --git a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc index 8c2b37b327..21fd2a0e5f 100644 --- a/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc +++ b/mindspore/ccsrc/frontend/parallel/step_auto_parallel.cc @@ -248,9 +248,25 @@ std::vector ExtractOutputTypeByNode(const CNodePtr &node) { } bool IsElementWiseOperator(const std::string &op_name) { - static const std::set elementwise_op = {ACTIVATION, GELU, TANH, SOFTMAX, LOG_SOFTMAX, RELU, - SQRT, CAST, POW, EXP, LOG, COS, - ACOS, LOGICALNOT, NEG, SQUARE, SIGMOID}; + static const std::set elementwise_op = {ACTIVATION, GELU, TANH, + SOFTMAX, LOG_SOFTMAX, RELU, + SQRT, CAST, POW, + EXP, LOG, COS, + ACOS, LOGICALNOT, NEG, + SQUARE, SIGMOID, ABS, + ACOSH, ASIN, ASINH, + ATAN, ATANH, CEIL, + COSH, EXPM1, LOG1P, + SIN, SINH, TAN, + RSQRT, RECIPROCAL, INV, + ROUND, FLOOR, SIGN, + ERF, ERFC, ZEROSLIKE, + ONESLIKE, BESSELI0E, MOD, + ASSIGN, ASSIGN_ADD, ATAN2, + DIVNONAN, LOGICALAND, ELU, + LOGICALOR, RELU6, SOFTPLUS, + SOFTSIGN, LESS, LESSEQUAL, + BESSELI1E, GREATEREQUAL, APPROXIMATEEQUAL}; auto iter = elementwise_op.find(op_name); return (iter != elementwise_op.end()); } @@ -265,7 +281,10 @@ bool IsSplittableOperator(const std::string &op_name) { LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT, CONCAT, STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE, SPARSE_GATHERV2, TILE, DROPOUT, SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, - EMBEDDING_LOOKUP, FUSE_BATCH_NORM_EX, SPLIT, BROADCAST_TO}; + EMBEDDING_LOOKUP, FUSE_BATCH_NORM_EX, SPLIT, BROADCAST_TO, ABS, ACOSH, ASIN, ASINH, ATAN, ATANH, CEIL, COSH, + EXPM1, LOG1P, SIN, SINH, TAN, RSQRT, INV, RECIPROCAL, ROUND, FLOOR, SIGN, ERF, ERFC, ZEROSLIKE, ONESLIKE, + BESSELI0E, BESSELI1E, FLOORMOD, ASSIGN, ASSIGN_ADD, ATAN2, DIVNONAN, LOGICALAND, LOGICALOR, ELU, RELU6, RELUV2, + SOFTPLUS, SOFTSIGN, GREATEREQUAL, LESSEQUAL, LESS, APPROXIMATEEQUAL, MOD}; // clang-format on auto iter = splittable_op.find(op_name);