diff --git a/example/resnet50_imagenet2012/run_distribute_train.sh b/example/resnet50_imagenet2012/run_distribute_train.sh
index 235a48e9c8..22157608e6 100755
--- a/example/resnet50_imagenet2012/run_distribute_train.sh
+++ b/example/resnet50_imagenet2012/run_distribute_train.sh
@@ -16,7 +16,7 @@
 
 if [ $# != 2 ] && [ $# != 3 ]
 then
-    echo "Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)" 
+    echo "Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH] [PRETRAINED_CKPT_PATH](optional)"
     exit 1
 fi
 
@@ -32,7 +32,7 @@
 PATH1=$(get_real_path $1)
 PATH2=$(get_real_path $2)
 if [ $# == 3 ]
 then
-    PATH3=$(get_real_path $3) 
+    PATH3=$(get_real_path $3)
 fi
 if [ ! -f "$PATH1" ]
@@ -47,11 +47,11 @@
 then
     exit 1
 fi
-if [ ! -f "$PATH3" ]
-then
+if [ $# == 3 ] && [ ! -f "$PATH3" ]
+then
     echo "error: PRETRAINED_CKPT_PATH=$PATH3 is not a file"
     exit 1
-fi 
+fi
 
 ulimit -u unlimited
 export DEVICE_NUM=8
diff --git a/example/resnet50_imagenet2012/run_infer.sh b/example/resnet50_imagenet2012/run_infer.sh
index 14d7faf981..1482b63f5f 100755
--- a/example/resnet50_imagenet2012/run_infer.sh
+++ b/example/resnet50_imagenet2012/run_infer.sh
@@ -34,13 +34,13 @@
 PATH2=$(get_real_path $2)
 
 if [ ! -d $PATH1 ]
 then
-    echo "error: DATASET_PATH=$1 is not a directory"
+    echo "error: DATASET_PATH=$PATH1 is not a directory"
     exit 1
 fi
 
 if [ ! -f $PATH2 ]
 then
-    echo "error: CHECKPOINT_PATH=$2 is not a file"
+    echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
     exit 1
 fi
diff --git a/example/resnet50_imagenet2012/run_standalone_train.sh b/example/resnet50_imagenet2012/run_standalone_train.sh
index c4dc95b7eb..e0eb5efaf0 100755
--- a/example/resnet50_imagenet2012/run_standalone_train.sh
+++ b/example/resnet50_imagenet2012/run_standalone_train.sh
@@ -31,17 +31,17 @@ get_real_path(){
 PATH1=$(get_real_path $1)
 if [ $# == 2 ]
 then
-    PATH2=$(get_real_path $2) 
+    PATH2=$(get_real_path $2)
 fi
 
 if [ ! -d "$PATH1" ]
 then
     echo "error: DATASET_PATH=$PATH1 is not a directory"
     exit 1
-fi 
+fi
 
-if [ ! -f "$PATH2" ]
-then
+if [ $# == 2 ] && [ ! -f "$PATH2" ]
+then
     echo "error: PRETRAINED_CKPT_PATH=$PATH2 is not a file"
     exit 1
 fi
@@ -62,7 +62,7 @@ cd ./train || exit
 echo "start training for device $DEVICE_ID"
 env > env.log
 if [ $# == 1 ]
-then 
+then
     python train.py --do_train=True --dataset_path=$PATH1 &> log &
 else
     python train.py --do_train=True --dataset_path=$PATH1 --pre_trained=$PATH2 &> log &
diff --git a/mindspore/ccsrc/operator/ops.cc b/mindspore/ccsrc/operator/ops.cc
index f267b9a73e..12c4afd894 100755
--- a/mindspore/ccsrc/operator/ops.cc
+++ b/mindspore/ccsrc/operator/ops.cc
@@ -246,7 +246,6 @@ const PrimitivePtr kPrimNotInDict = std::make_shared<Primitive>("not_in_dict");
 const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOperator");
 const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv");
 const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");
-const PrimitivePtr kPrimAllReduce = std::make_shared<Primitive>("AllReduce");
 
 // Debug ops
 const PrimitivePtr kPrimScalarSummary = std::make_shared<Primitive>("ScalarSummary");
diff --git a/mindspore/ccsrc/operator/ops.h b/mindspore/ccsrc/operator/ops.h
index 5e1564f2c1..e67342527c 100755
--- a/mindspore/ccsrc/operator/ops.h
+++ b/mindspore/ccsrc/operator/ops.h
@@ -252,7 +252,6 @@ extern const PrimitivePtr kPrimInDict;
 extern const PrimitivePtr kPrimNotInDict;
 
 // Comm ops
-extern const PrimitivePtr kPrimAllReduce;
 extern const PrimitivePtr kPrimMirror;
 extern const PrimitivePtr kPrimVirtualDiv;
 extern const PrimitivePtr kPrimVirtualDataset;
diff --git a/mindspore/ccsrc/optimizer/irpass.cc b/mindspore/ccsrc/optimizer/irpass.cc
index 2ac0bc21c7..4a867546b6 100644
--- a/mindspore/ccsrc/optimizer/irpass.cc
+++ b/mindspore/ccsrc/optimizer/irpass.cc
@@ -54,7 +54,6 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
                      {prim::kPrimInsertGradientOf, prim::kPrimHookBackward, prim::kPrimPrintShapeType,
                       prim::kPrimGetRefKey, prim::kPrimMirror, prim::kPrimVirtualDiv});
   zero_like_fill_zero_ = MakeSubstitution(ZeroLikeFillZero(), "zero_like_fill_zero", prim::kPrimZerosLikeTensor);
-  adjust_all_reduce_mul_add_ = MakeSubstitution(AdjustAllReduceMulAdd(), "adjust_all_reduce_mul_add", prim::kPrimAddN);
 
   // ops eliminate
   item_tuple_eliminate_ =
diff --git a/mindspore/ccsrc/optimizer/irpass.h b/mindspore/ccsrc/optimizer/irpass.h
index e834d69b69..02bfee65d6 100644
--- a/mindspore/ccsrc/optimizer/irpass.h
+++ b/mindspore/ccsrc/optimizer/irpass.h
@@ -35,7 +35,6 @@ class OptimizeIRPassLib {
   SubstitutionPtr arithmetic_simplify_;
   SubstitutionPtr special_op_eliminate_;
   SubstitutionPtr zero_like_fill_zero_;
-  SubstitutionPtr adjust_all_reduce_mul_add_;
 
   // ops eliminate
   SubstitutionPtr item_tuple_eliminate_;
diff --git a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h b/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h
index b33b4c613d..ab191aab20 100644
--- a/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h
+++ b/mindspore/ccsrc/optimizer/irpass/arithmetic_simplify.h
@@ -228,115 +228,6 @@ class ConstantDuplicateMul : public AnfVisitor {
   CNodePtr cnode_;
 };
 
-// grad = AllReduce(grad) / worker_number
-// grad = grad + weight * decy
-// ->
-// grad = grad + weight * decy
-// grad = AllReduce(grad) / worker_number
-
-// {prim::kPrimAddN, {prim::kPrimMakeTuple, {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}, Z}} ->
-// {prim::kPrimMul, {prim::kPrimAllReduce, {prim::kPrimAddN, {prim::kPrimMakeTuple, Z, X}}}, Y}
-class AdjustAllReduceMulAdd : public AnfVisitor {
- public:
-  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
-    Reset();
-    // {prim::kPrimAddN, Zs}
-    if (!IsPrimitiveCNode(node, prim::kPrimAddN)) {
-      return nullptr;
-    }
-    auto addn = node->cast<CNodePtr>();
-    if (addn->size() != 2) {
-      return nullptr;
-    }
-    AnfVisitor::Match(prim::kPrimMakeTuple, {IsNode, IsNode})(addn->input(1));
-    if (x_ == nullptr || y_ == nullptr || z_ == nullptr || all_reduce_fg_ == nullptr) {
-      return nullptr;
-    }
-    auto addn_maketuple = addn->input(1);
-
-    auto fg = all_reduce_fg_;
-    // addn inputs cross the graph, make the inputs same as allreduce node.
-    if (z_->isa<CNode>() && fg != z_->func_graph()) {
-      auto cnode_z = z_->cast<CNodePtr>();
-      z_ = NewCNode(cnode_z->inputs(), fg);
-    }
-
-    auto addn_op_node = addn->input(0);
-    auto make_tuple_op_node = addn->input(1)->cast<CNodePtr>()->input(0);
-
-    AnfNodePtr tuple = NewCNode({make_tuple_op_node, z_, x_}, fg);
-    AnfNodePtr add = NewCNode({addn_op_node, tuple}, fg);
-    AnfNodePtr all_reduce = NewCNode({all_reduce_, add}, fg);
-    AnfNodePtr mul = NewCNode({mul_, all_reduce, y_}, fg);
-    ProcessDependEdge(fg, addn_maketuple, all_reduce);
-    return mul;
-  }
-
-  void ProcessDependEdge(const FuncGraphPtr &fg, const AnfNodePtr &addn_maketuple, const AnfNodePtr &new_node) {
-    // If has dynamic loss scale.
-    auto &users_map = fg->manager()->node_users();
-    auto it = users_map.find(mul_cnode_);
-    if (it != users_map.end()) {
-      auto users = it->second;
-      for (auto &user_pair : users) {
-        auto node = user_pair.first;
-        if (node != addn_maketuple) {
-          if (IsPrimitiveCNode(node, prim::kPrimMakeTuple)) {
-            fg->manager()->SetEdge(node, user_pair.second, new_node);
-          }
-        }
-      }
-    }
-  }
-
-  void Visit(const AnfNodePtr &node) override {
-    if (level_ == 0) {
-      level_ = 1;
-      is_reduce_match_ = false;
-      // {prim::kPrimMul, {prim::kPrimAllReduce, X}, Y}
-      AnfVisitor::Match(prim::kPrimMul)(node);
-      level_ = 0;
-      if (is_reduce_match_) {
-        mul_ = node->cast<CNodePtr>()->input(0);
-        mul_cnode_ = node->cast<CNodePtr>();
-        y_ = tmp_;
-      } else {
-        z_ = node;
-      }
-    }
-
-    if (level_ == 1) {
-      // {prim::kPrimAllReduce, X}
-      if (IsPrimitiveCNode(node, prim::kPrimAllReduce)) {
-        auto cnode = node->cast<CNodePtr>();
-        if (cnode->size() > 1) {
-          all_reduce_ = cnode->input(0);
-          x_ = cnode->input(1);
-          is_reduce_match_ = true;
-          all_reduce_fg_ = cnode->func_graph();
-        }
-      } else {
-        tmp_ = node;
-      }
-    }
-  }
-
-  void Reset() {
-    level_ = 0;
-    is_reduce_match_ = false;
-    x_ = nullptr;
-    y_ = nullptr;
-    z_ = nullptr;
-    tmp_ = nullptr;
-    all_reduce_fg_ = nullptr;
-  }
-
- private:
-  int level_{0};
-  bool is_reduce_match_{false};
-  AnfNodePtr x_{nullptr}, y_{nullptr}, z_{nullptr}, tmp_{nullptr};
-  AnfNodePtr all_reduce_{nullptr}, mul_{nullptr}, mul_cnode_{nullptr};
-  FuncGraphPtr all_reduce_fg_{nullptr};
-};
-
 class ArithmeticSimplify {
  public:
   ArithmeticSimplify()
diff --git a/mindspore/ccsrc/pipeline/parse/function_block.h b/mindspore/ccsrc/pipeline/parse/function_block.h
index 5341d33b21..e7842903ee 100644
--- a/mindspore/ccsrc/pipeline/parse/function_block.h
+++ b/mindspore/ccsrc/pipeline/parse/function_block.h
@@ -28,7 +28,6 @@
 #include <unordered_map>
 #include "pipeline/parse/parse_base.h"
 #include "utils/log_adapter.h"
-#include "utils/ordered_map.h"
 
 namespace mindspore {
 namespace parse {
@@ -100,7 +99,7 @@ class FunctionBlock : public std::enable_shared_from_this<FunctionBlock> {
   std::unordered_map<ParameterPtr, AnfNodePtr> removable_phis_;
 
   // set state nodes need to insert before function return nodes.
-  OrderedMap<AnfNodePtr, std::string> state_assign_;
+  std::unordered_map<AnfNodePtr, std::string> state_assign_;
 
   // hold declared global variables in function
   std::set<std::string> global_vars_;
diff --git a/mindspore/ccsrc/pipeline/pass.cc b/mindspore/ccsrc/pipeline/pass.cc
index a95c02ded6..0a5af9e3df 100644
--- a/mindspore/ccsrc/pipeline/pass.cc
+++ b/mindspore/ccsrc/pipeline/pass.cc
@@ -82,7 +82,6 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
     // Arithmetic simplifications
     irpass.arithmetic_simplify_,
     irpass.addn_zero_filter_,
-    irpass.adjust_all_reduce_mul_add_,
 
     // Miscellaneous
     irpass.item_tuple_eliminate_,
diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py
index 74fd3c3b3e..19a9ffd79d 100644
--- a/mindspore/ops/operations/array_ops.py
+++ b/mindspore/ops/operations/array_ops.py
@@ -1275,7 +1275,7 @@ class UnsortedSegmentSum(PrimitiveWithInfer):
         Tensor, the shape is :math:`(z, x_{N+1}, ..., x_R)`.
 
     Examples:
-        >>> input_x = Tensor([1, 2, 3, 4], mindspore.float)
+        >>> input_x = Tensor([1, 2, 3, 4], mindspore.float32)
         >>> segment_ids = Tensor([0, 0, 1, 2], mindspore.int32)
         >>> num_segments = 4
         >>> P.UnsortedSegmentSum()(input_x, segment_ids, num_segments)
diff --git a/mindspore/ops/operations/nn_ops.py b/mindspore/ops/operations/nn_ops.py
index 962df9e7eb..a0d9682d17 100644
--- a/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/ops/operations/nn_ops.py
@@ -1855,7 +1855,7 @@ class LayerNorm(Primitive):
     `Layer Normalization <https://arxiv.org/abs/1607.06450>`_.
 
     .. math::
-        y = \frac{x - mean]}{\sqrt{variance + \epsilon}} * \gamma + \beta
+        y = \frac{x - mean}{\sqrt{variance + \epsilon}} * \gamma + \beta
 
     where :math:`\gamma` is scale, :math:`\beta` is bias, :math:`\epsilon` is epsilon.
diff --git a/mindspore/ops/primitive.py b/mindspore/ops/primitive.py
index 908c0245bb..95e148204b 100644
--- a/mindspore/ops/primitive.py
+++ b/mindspore/ops/primitive.py
@@ -284,8 +284,7 @@ def prim_attr_register(fn):
 
 def constexpr(fn=None, get_instance=True, name=None):
     """
-    Makes a PrimitiveWithInfer operator, which infer the value while compiling. We can define a function
-    to compute between constant variable and used in construct.
+    Makes a PrimitiveWithInfer operator, which infers the value while compiling.
 
     Args:
         fn (function): A `fn` use as the infer_value of the output operator.
diff --git a/tests/ut/cpp/optimizer/lib_test.cc b/tests/ut/cpp/optimizer/lib_test.cc
index 1ed1fed43d..2d4cf0e78e 100644
--- a/tests/ut/cpp/optimizer/lib_test.cc
+++ b/tests/ut/cpp/optimizer/lib_test.cc
@@ -556,24 +556,5 @@ TEST_F(TestOptLib, test_constant_duplicate_mul) {
   ASSERT_TRUE(CheckOpt(beforerl, after, patterns));
   ASSERT_TRUE(CheckOpt(beforerr, after, patterns));
 }
-
-TEST_F(TestOptLib, test_adjust_allreduce_mul_add) {
-  FuncGraphPtr beforell = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforell");
-  FuncGraphPtr beforelr = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforelr");
-  FuncGraphPtr beforerl = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforerl");
-  FuncGraphPtr beforerr = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "beforerr");
-  FuncGraphPtr after1 = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "after1");
-  FuncGraphPtr before2r = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "before2r");
-  FuncGraphPtr before2l = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "before2l");
-  FuncGraphPtr after2 = getPyFun.CallAndParseRet("test_adjust_allreduce_mul_add", "after2");
-  auto patterns = std::vector<SubstitutionPtr>({irpass.adjust_all_reduce_mul_add_});
-  ASSERT_TRUE(CheckOpt(beforell, after1, patterns));
-  ASSERT_TRUE(CheckOpt(beforelr, after1, patterns));
-  ASSERT_TRUE(CheckOpt(beforerl, after1, patterns));
-  ASSERT_TRUE(CheckOpt(beforerr, after1, patterns));
-  ASSERT_TRUE(CheckOpt(before2l, after2, patterns));
-  ASSERT_TRUE(CheckOpt(before2r, after2, patterns));
-}
-
 }  // namespace opt
 }  // namespace mindspore
diff --git a/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py b/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py
index 28a3d8e7d8..28543043e7 100644
--- a/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py
+++ b/tests/ut/cpp/python_input/gtest_input/optimizer/opt_test.py
@@ -1045,8 +1045,8 @@ def test_print_tuple_wrapper(tag):
 
 def test_constant_duplicate_mul(tag):
     fns = FnDict()
-    Mul = Primitive('Mul');
-    Sqrt = Primitive('Sqrt');
+    Mul = Primitive('Mul')
+    Sqrt = Primitive('Sqrt')
 
     x = Tensor(np.array([[2, 2], [2, 3]]).astype('float32'))
     tensor1 = Tensor(np.array([[1.2, 2.1], [2.2, 3.2]]).astype('float32'))
@@ -1073,44 +1073,3 @@ def test_constant_duplicate_mul(tag):
         return Mul(Sqrt(x), Mul(tensor1, tensor2))
 
     return fns[tag]
-
-
-def test_adjust_allreduce_mul_add(tag):
-    fns = FnDict()
-    Mul = Primitive('Mul')
-    AddN = Primitive('AddN')
-    AllReduce = Primitive('AllReduce')
-
-    @fns
-    def beforell(x, y, z):
-        return AddN((z, Mul(y, AllReduce(x))))
-
-    @fns
-    def beforelr(x, y, z):
-        return AddN((z, Mul(AllReduce(x), y)))
-
-    @fns
-    def beforerl(x, y, z):
-        return AddN((Mul(y, AllReduce(x)), z))
-
-    @fns
-    def beforerr(x, y, z):
-        return AddN((Mul(AllReduce(x), y), z))
-
-    @fns
-    def after1(x, y, z):
-        return Mul(AllReduce(AddN((z, x))), y)
-
-    @fns
-    def before2r(x, y, z):
-        return AddN((Mul(AllReduce(x), y), Mul(z, z)))
-
-    @fns
-    def before2l(x, y, z):
-        return AddN((Mul(z, z), Mul(AllReduce(x), y)))
-
-    @fns
-    def after2(x, y, z):
-        return Mul(AllReduce(AddN((Mul(z, z), x))), y)
-
-    return fns[tag]
diff --git a/tests/ut/python/train/test_amp.py b/tests/ut/python/train/test_amp.py
index 54f2081b6c..fe08809be1 100644
--- a/tests/ut/python/train/test_amp.py
+++ b/tests/ut/python/train/test_amp.py
@@ -20,14 +20,9 @@ import mindspore.context as context
 from mindspore import Tensor
 from mindspore import amp
 from mindspore import nn
-from mindspore.train import Model, ParallelMode
-from mindspore import Tensor
-from mindspore.common import dtype as mstype
-import mindspore.context as context
-from mindspore.model_zoo.resnet import resnet50
+from mindspore.train import Model
 from ....dataset_mock import MindData
-from mindspore.parallel._auto_parallel_context import auto_parallel_context
-from mindspore.communication.management import init
+
 
 def setup_module(module):
     context.set_context(mode=context.GRAPH_MODE)
@@ -143,22 +138,3 @@ def test_compile_model_train_O2():
     with pytest.raises(ValueError):
         # not actual run, the metrics step will fail, check if compile ok.
         model.eval(dataset)
-
-def test_compile_model_train_O2_parallel():
-    dataset_types = (np.float32, np.float32)
-    dataset_shapes = ((16, 16), (16, 16))
-
-    dataset = MindDataSet(dataset_types, dataset_shapes)
-
-    net = NetNoLoss(16, 16)
-    loss = nn.MSELoss()
-    optimizer = nn.Momentum(net.trainable_params(), 0.1, 0.9, 0.00004, 1024.0)
-
-    context.set_auto_parallel_context(
-        global_rank=0, device_num=8,
-        mirror_mean=True, parameter_broadcast=True,
-        parallel_mode=ParallelMode.DATA_PARALLEL)
-    init()
-
-    model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={"acc"}, amp_level="O2")
-    model.train(2, dataset, dataset_sink_mode=False)