From 7c237620ba6642c2b6f22a0cfbffa2b2b3d62336 Mon Sep 17 00:00:00 2001
From: yangzhenzhang <285824651@qq.com>
Date: Mon, 25 May 2020 19:32:23 +0800
Subject: [PATCH] add sigmoid op

---
 mindspore/ccsrc/parallel/dynamic_creator.h    |  1 +
 .../ccsrc/parallel/ops_info/activation_info.h |  8 ++
 .../ccsrc/parallel/step_auto_parallel.cc      | 91 ++++--------------
 .../parallel/test_auto_parallel_activation.py | 55 +++++++++++
 4 files changed, 82 insertions(+), 73 deletions(-)
 create mode 100644 tests/ut/python/parallel/test_auto_parallel_activation.py

diff --git a/mindspore/ccsrc/parallel/dynamic_creator.h b/mindspore/ccsrc/parallel/dynamic_creator.h
index 7b09193dbd..4fd5f34cf2 100644
--- a/mindspore/ccsrc/parallel/dynamic_creator.h
+++ b/mindspore/ccsrc/parallel/dynamic_creator.h
@@ -122,6 +122,7 @@ REGISTER(AssignSubInfo);
 REGISTER(ReLUInfo);
 REGISTER(GatherV2Info);
 REGISTER(SqrtInfo);
+REGISTER(SigmoidInfo);
 REGISTER(GetNextInfo);
 REGISTER(NegInfo);
 REGISTER(BatchMatMulInfo);
diff --git a/mindspore/ccsrc/parallel/ops_info/activation_info.h b/mindspore/ccsrc/parallel/ops_info/activation_info.h
index a71c6b6df7..cd66bf8e8b 100644
--- a/mindspore/ccsrc/parallel/ops_info/activation_info.h
+++ b/mindspore/ccsrc/parallel/ops_info/activation_info.h
@@ -211,6 +211,14 @@ class SquareInfo : public ActivationOther {
       : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
   ~SquareInfo() override = default;
 };
+
+class SigmoidInfo : public ActivationOther {
+ public:
+  SigmoidInfo(const std::string &name, const Shapes &inputs_shape, const Shapes &outputs_shape,
+              const PrimitiveAttrs &attrs)
+      : ActivationOther(name, inputs_shape, outputs_shape, attrs) {}
+  ~SigmoidInfo() override = default;
+};
 }  // namespace parallel
 }  // namespace mindspore
 #endif  // MINDSPORE_CCSRC_PARALLEL_OPS_INFO_ACTIVATION_INFO_H_
diff --git a/mindspore/ccsrc/parallel/step_auto_parallel.cc b/mindspore/ccsrc/parallel/step_auto_parallel.cc
index 3811efdd6a..f842cabef4 100644
--- a/mindspore/ccsrc/parallel/step_auto_parallel.cc
+++ b/mindspore/ccsrc/parallel/step_auto_parallel.cc
@@ -48,74 +48,6 @@
 
 namespace mindspore {
 namespace parallel {
-// splittable_op_ will continuously be updated
-std::vector<std::string> splittable_op_ = {MATMUL,
-                                           GELU,
-                                           TANH,
-                                           SOFTMAX,
-                                           LOG_SOFTMAX,
-                                           ACTIVATION,
-                                           PRELU,
-                                           FLOORDIV,
-                                           L2_NORMALIZE,
-                                           TRANSPOSE,
-                                           RESHAPE,
-                                           TENSOR_ADD,
-                                           SUB,
-                                           MUL,
-                                           DIV,
-                                           GREATER,
-                                           MAXPOOL,
-                                           MAXPOOLV2,
-                                           VIRTUAL_DATA_SET,
-                                           SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS,
-                                           RELU,
-                                           ONEHOT,
-                                           DROPOUT_DO_MASK,
-                                           REDUCE_MAX,
-                                           REDUCE_MIN,
-                                           ARGMAXWITHVALUE,
-                                           ARGMINWITHVALUE,
-                                           REDUCE_SUM,
-                                           CONV2D,
-                                           FUSE_BATCH_NORM,
-                                           POOLING,
-                                           SOFTMAX_CROSS_ENTROPY_WITH_LOGITS,
-                                           SIGMOID_CROSS_ENTROPY_WITH_LOGITS,
-                                           MAX_POOL_WITH_ARGMAX,
-                                           SIMPLE_MEAN,
-                                           FLATTEN,
-                                           BATCH_NORM,
-                                           LAYER_NORM,
-                                           BIAS_ADD,
-                                           ASSIGN_SUB,
-                                           COS,
-                                           ACOS,
-                                           EXP,
-                                           LOG,
-                                           REDUCE_MEAN,
-                                           REAL_DIV,
-                                           SIGMOID,
-                                           POW,
-                                           MAXIMUM,
-                                           MINIMUM,
-                                           EQUAL,
-                                           NOT_EQUAL,
-                                           LOGICALNOT,
-                                           GATHERV2,
-                                           STRIDEDSLICE,
-                                           SQRT,
-                                           GET_NEXT,
-                                           CAST,
-                                           NEG,
-                                           SQUARE,
-                                           BATCH_MATMUL,
-                                           EXPAND_DIMS,
-                                           SQUEEZE};
-
-std::vector<std::string> elementwise_op_ = {ACTIVATION, GELU, TANH, SOFTMAX, LOG_SOFTMAX, RELU, SQRT, CAST,
-                                            POW, EXP, LOG, COS, ACOS, LOGICALNOT, NEG, SQUARE};
-
 bool StepAutoParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &) {
   MS_EXCEPTION_IF_NULL(root);
   MS_EXCEPTION_IF_NULL(ParallelContext::GetInstance());
@@ -314,14 +246,27 @@ std::vector<TypePtr> ExtractOutputTypeByNode(const CNodePtr &node) {
 }
 
 bool IsElementWiseOperator(const std::string &op_name) {
-  auto iter = std::find(elementwise_op_.begin(), elementwise_op_.end(), op_name);
-  return (iter != elementwise_op_.end());
+  static const std::set<std::string> elementwise_op = {ACTIVATION, GELU,       TANH, SOFTMAX, LOG_SOFTMAX, RELU,
+                                                       SQRT,       CAST,       POW,  EXP,     LOG,         COS,
+                                                       ACOS,       LOGICALNOT, NEG,  SQUARE,  SIGMOID};
+  auto iter = elementwise_op.find(op_name);
+  return (iter != elementwise_op.end());
 }
 
 bool IsSplittableOperator(const std::string &op_name) {
-  std::vector<std::string>::iterator iter;
-  iter = std::find(splittable_op_.begin(), splittable_op_.end(), op_name);
-  return (iter != splittable_op_.end());
+  // clang-format off
+  static const std::set<std::string> splittable_op =
+    {MATMUL, TRANSPOSE, GELU, TANH, SOFTMAX, SUB, MUL, DIV, RESHAPE, GREATER, LOG_SOFTMAX, ACTIVATION, PRELU,
+     FLOORDIV, L2_NORMALIZE, TENSOR_ADD, MAXPOOL, MAXPOOLV2, VIRTUAL_DATA_SET, RELU, ONEHOT, DROPOUT_DO_MASK,
+     REDUCE_MAX, REDUCE_MIN, ARGMAXWITHVALUE, ARGMINWITHVALUE, REDUCE_SUM, CONV2D, FUSE_BATCH_NORM, POOLING,
+     MAX_POOL_WITH_ARGMAX, SIMPLE_MEAN, FLATTEN, BATCH_NORM, LAYER_NORM, BIAS_ADD, ASSIGN_SUB, COS, ACOS, EXP,
+     LOG, REDUCE_MEAN, REAL_DIV, SIGMOID, POW, MAXIMUM, MINIMUM, EQUAL, NOT_EQUAL, LOGICALNOT, GATHERV2, SQRT,
+     STRIDEDSLICE, GET_NEXT, CAST, NEG, SQUARE, BATCH_MATMUL, EXPAND_DIMS, SQUEEZE,
+     SOFTMAX_CROSS_ENTROPY_WITH_LOGITS, SIGMOID_CROSS_ENTROPY_WITH_LOGITS, SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS};
+  // clang-format on
+
+  auto iter = splittable_op.find(op_name);
+  return (iter != splittable_op.end());
 }
 
 bool IsAutoParallelCareNode(const CNodePtr &cnode) {
diff --git a/tests/ut/python/parallel/test_auto_parallel_activation.py b/tests/ut/python/parallel/test_auto_parallel_activation.py
new file mode 100644
index 0000000000..815411dc16
--- /dev/null
+++ b/tests/ut/python/parallel/test_auto_parallel_activation.py
@@ -0,0 +1,55 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+import mindspore as ms
+from mindspore import context, Tensor, Parameter
+from mindspore.common.api import _executor
+from mindspore.nn import Cell, TrainOneStepCell, Momentum
+from mindspore.ops import operations as P
+
+
+class Net(Cell):
+    def __init__(self, mul_weight, strategy1=None, strategy2=None):
+        super().__init__()
+        self.mul = P.Mul().set_strategy(strategy1)
+        self.sigmoid = P.Sigmoid().set_strategy(strategy2)
+        self.mul_weight = Parameter(mul_weight, "w1")
+
+    def construct(self, x, b):
+        out = self.mul(x, self.mul_weight)
+        out = self.sigmoid(out)
+        return out
+
+
+_x = Tensor(np.ones([64, 32]), dtype=ms.float32)
+_w1 = Tensor(np.ones([64, 32]), dtype=ms.float32)
+_b = Tensor(np.ones([64, 32]), dtype=ms.float32)
+
+
+def compile_net(net):
+    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
+    train_net = TrainOneStepCell(net, optimizer)
+    train_net.set_auto_parallel()
+    _executor.compile(train_net, _x, _b)
+    context.reset_auto_parallel_context()
+
+
+def test_auto_parallel_activation():
+    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0)
+    strategy1 = ((4, 4), (4, 4))
+    strategy2 = None
+    net = Net(_w1, strategy1, strategy2)
+    compile_net(net)
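
For illustration, a hypothetical companion case (not part of this patch, so it is kept outside the diff):
with SigmoidInfo registered, P.Sigmoid can also be given an explicit sharding strategy under
semi_auto_parallel, instead of leaving strategy2 as None for the auto-parallel strategy search. The sketch
below reuses Net, compile_net, and _w1 from the new test file above; the test name and the ((4, 4),)
strategy are assumptions, not part of the committed change.

def test_semi_auto_parallel_sigmoid():
    # Hypothetical variant of the test above: semi_auto_parallel with an explicit
    # strategy on Sigmoid. Sigmoid takes one input, so its strategy is a 1-tuple;
    # (4, 4) splits the 64x32 input across 4 * 4 = 16 devices.
    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=16, global_rank=0)
    strategy1 = ((4, 4), (4, 4))
    strategy2 = ((4, 4),)
    net = Net(_w1, strategy1, strategy2)
    compile_net(net)

Because SigmoidInfo derives from ActivationOther, the element-wise strategy checks are inherited, so the
explicit strategy only needs to match the single input's shape.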