From ec9793861f5e2304637c3a363d847f78c9350297 Mon Sep 17 00:00:00 2001
From: Ziyan
Date: Tue, 26 Jan 2021 15:46:58 +0800
Subject: [PATCH] fix grad accu

---
 .../ccsrc/backend/session/session_basic.cc    |  10 +-
 mindspore/ccsrc/frontend/optimizer/irpass.cc  |   1 +
 mindspore/ccsrc/frontend/optimizer/irpass.h   |   1 +
 .../optimizer/irpass/special_op_eliminate.h   |  26 ++
 mindspore/ccsrc/frontend/parallel/context.cc  |  26 +-
 .../frontend/parallel/ops_info/ops_utils.h    |   2 +
 .../ccsrc/frontend/parallel/step_parallel.cc  | 129 +++++++-
 mindspore/ccsrc/pipeline/jit/pass.cc          |   1 +
 mindspore/core/base/core_ops.h                |   1 +
 mindspore/ops/_grad/grad_comm_ops.py          |  46 ++-
 mindspore/ops/operations/__init__.py          |   2 +-
 mindspore/ops/operations/comm_ops.py          |  13 +
 model_zoo/official/nlp/gpt/src/gpt.py         |   1 +
 .../parallel/test_get_parameter_layout.py     |   1 +
 .../python/parallel/test_grad_accumulation.py | 307 ------------------
 15 files changed, 222 insertions(+), 345 deletions(-)
 delete mode 100644 tests/ut/python/parallel/test_grad_accumulation.py

diff --git a/mindspore/ccsrc/backend/session/session_basic.cc b/mindspore/ccsrc/backend/session/session_basic.cc
index 6d0a10981b..e8c840046b 100644
--- a/mindspore/ccsrc/backend/session/session_basic.cc
+++ b/mindspore/ccsrc/backend/session/session_basic.cc
@@ -51,7 +51,7 @@
 namespace mindspore {
 namespace session {
-static std::shared_ptr<std::map<ValuePtr, ParameterPtr>> python_paras;
+static std::shared_ptr<std::map<ParamInfoPtr, ParameterPtr>> python_paras;
 void ClearPythonParasMap() { python_paras = nullptr; }
 namespace {
 const int kSummaryGetItem = 2;
@@ -106,7 +106,7 @@ bool CheckIfNeedCreateOutputTensor(const AnfNodePtr &node) {
   return false;
 }
 
-ValuePtr GetParamDefaultValue(const AnfNodePtr &node) {
+ParamInfoPtr GetParamDefaultValue(const AnfNodePtr &node) {
   if (node == nullptr) {
     return nullptr;
   }
@@ -114,7 +114,7 @@ ValuePtr GetParamDefaultValue(const AnfNodePtr &node) {
   if (parameter == nullptr || !parameter->has_default()) {
     return nullptr;
   }
-  return parameter->default_param();
+  return parameter->param_info();
 }
 
 tensor::TensorPtr CreateCNodeOutputTensor(const session::KernelWithIndex &node_output_pair,
@@ -747,7 +747,7 @@ ParameterPtr SessionBasic::CreateNewParameterFromParameter(const AnfNodePtr &anf
   ParameterPtr new_parameter = nullptr;
   // if parameter's python parameter has been exist a backend parameter, reuse the exist parameter
   if (python_paras == nullptr) {
-    python_paras = std::make_shared<std::map<ValuePtr, ParameterPtr>>();
+    python_paras = std::make_shared<std::map<ParamInfoPtr, ParameterPtr>>();
   }
   auto iter = python_paras->find(param_value);
   if (iter != python_paras->end()) {
@@ -1217,7 +1217,7 @@ ParameterPtr SessionBasic::CreateNewParameter(const AnfNodePtr &anf, KernelGraph
   auto param_value = GetParamDefaultValue(anf);
   ParameterPtr new_parameter = nullptr;
   if (python_paras == nullptr) {
-    python_paras = std::make_shared<std::map<ValuePtr, ParameterPtr>>();
+    python_paras = std::make_shared<std::map<ParamInfoPtr, ParameterPtr>>();
   }
   auto iter = python_paras->find(param_value);
   if (iter != python_paras->end()) {
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.cc b/mindspore/ccsrc/frontend/optimizer/irpass.cc
index af1f6a5ca4..ac1e262a97 100644
--- a/mindspore/ccsrc/frontend/optimizer/irpass.cc
+++ b/mindspore/ccsrc/frontend/optimizer/irpass.cc
@@ -88,6 +88,7 @@ OptimizeIRPassLib::OptimizeIRPassLib() {
                                           prim::kPrimMirrorMiniStep);
   mini_step_allgather_replace_ = MakeSubstitution(std::make_shared<MiniStepAllGatherPass>(),
                                                   "mini_step_allgather_replace", prim::kPrimMiniStepAllGather);
+  virtual_add_elim_ = MakeSubstitution(std::make_shared<VirtualAddEliminater>(), "virtual_add", prim::kPrimVirtualAdd);
   check_bprop_eliminate_ =
     MakeSubstitution(std::make_shared<CheckBpropEliminater>(), "check_bprop_eliminate", prim::kPrimCheckBprop);
   reset_defer_inline_ =
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass.h b/mindspore/ccsrc/frontend/optimizer/irpass.h
index 1da28f1681..b3604a2cdb 100644
--- a/mindspore/ccsrc/frontend/optimizer/irpass.h
+++ b/mindspore/ccsrc/frontend/optimizer/irpass.h
@@ -52,6 +52,7 @@ class OptimizeIRPassLib {
   SubstitutionPtr depend_value_elim_;
   SubstitutionPtr all_reduce_const_elim_;
   SubstitutionPtr mirror_mini_step_elim_;
+  SubstitutionPtr virtual_add_elim_;
   SubstitutionPtr mini_step_allgather_replace_;
 
   // Env Item Eliminate
diff --git a/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h b/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h
index 41fb26c840..913834f8ca 100644
--- a/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h
+++ b/mindspore/ccsrc/frontend/optimizer/irpass/special_op_eliminate.h
@@ -175,6 +175,25 @@ class MirrorMiniStepEliminater : public AnfVisitor {
   void Visit(const AnfNodePtr &) override {}
 };
 
+// {prim::kPrimVirtualAdd, X, Z} -> X
+class VirtualAddEliminater : public AnfVisitor {
+ public:
+  AnfNodePtr operator()(const OptimizerPtr &, const AnfNodePtr &node) override {
+    if (!IsPrimitiveCNode(node, prim::kPrimVirtualAdd) || node->func_graph() == nullptr) {
+      return nullptr;
+    }
+
+    auto &inputs = node->cast<CNodePtr>()->inputs();
+    if (inputs.size() < 2) {
+      return nullptr;
+    }
+
+    return inputs[1];
+  }
+
+  void Visit(const AnfNodePtr &) override {}
+};
+
 // {prim::kPrimMiniStepAllGather, X, Z} -> {prim::kPrimAllGather, X}
 class MiniStepAllGatherPass : public AnfVisitor {
  public:
@@ -191,8 +210,15 @@
     MS_EXCEPTION_IF_NULL(prim);
     auto attrs = prim->attrs();
     std::string group = attrs[parallel::GROUP]->ToString();
+    auto fusion = attrs[parallel::FUSION];
     parallel::Operator op = parallel::CreateAllGatherOp(group);
     std::vector<AnfNodePtr> node_input = parallel::CreateInput(op, inputs[1], parallel::PARALLEL_OPTIMIZER_ALLGATHER);
+    auto prim_anf_node = node_input[0]->cast<ValueNodePtr>();
+    prim = GetValueNode<PrimitivePtr>(prim_anf_node);
+    MS_EXCEPTION_IF_NULL(prim);
+    attrs = prim->attrs();
+    attrs[parallel::FUSION] = fusion;
+    prim->SetAttrs(attrs);
     auto func_graph = inputs[1]->func_graph();
     CNodePtr new_node = func_graph->NewCNode(node_input);
     return new_node;
diff --git a/mindspore/ccsrc/frontend/parallel/context.cc b/mindspore/ccsrc/frontend/parallel/context.cc
index 9e832b09be..8002546194 100644
--- a/mindspore/ccsrc/frontend/parallel/context.cc
+++ b/mindspore/ccsrc/frontend/parallel/context.cc
@@ -155,13 +155,23 @@ const std::vector<uint32_t> ParallelContext::GetAllReduceFusionSplitSizes(const
 // Clear param_shapes before training in auto-parallel or semi-auto-parallel mode
 void ParallelContext::ParallelParameterContextInitShape(const FuncGraphPtr &func_graph) {
   MS_EXCEPTION_IF_NULL(func_graph);
-  if (func_graph->has_flag(AUTO_PARALLEL) &&
-      (!func_graph->has_flag(TRAINING) ||
-       (ParallelContext::GetInstance()->grad_accumulation_step() > 1 && !func_graph->has_flag(ACCUMULATION)))) {
+  if (!func_graph->has_flag(AUTO_PARALLEL)) {
+    return;
+  }
+
+  if (!func_graph->has_flag(TRAINING)) {
+    init_param_shape_ = false;
+    MS_LOG(INFO) << "In parallel evaluation or prediction, may need to restore the parameter shape";
+    return;
+  }
+
+  if ((ParallelContext::GetInstance()->grad_accumulation_step() > 1) && !func_graph->has_flag(ACCUMULATION)) {
     init_param_shape_ = false;
+    MS_LOG(INFO) << "In the second graph of parallel grad accumulation, need to restore the parameter shape";
   } else {
     param_shapes.clear();
     init_param_shape_ = true;
+    MS_LOG(INFO) << "Init the parameter shape dict";
   }
 }
 
@@ -171,6 +181,10 @@ void ParallelContext::ParallelParameterContextRestoreShape(const FuncGraphPtr &f
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(param_node);
   MS_EXCEPTION_IF_NULL(ptr);
+  if (!func_graph->has_flag(AUTO_PARALLEL)) {
+    return;
+  }
+
   if (init_param_shape_) {
     return;
   }
@@ -182,7 +196,7 @@ void ParallelContext::ParallelParameterContextRestoreShape(const FuncGraphPtr &f
   Shape shape = iter->second;
   std::shared_ptr<abstract::BaseShape> base_shape = std::make_shared<abstract::Shape>(shape);
   ptr->set_shape(base_shape);
-  MS_LOG(DEBUG) << "The parameter name is " << param_node->name() << ", the shape is " << shape;
+  MS_LOG(INFO) << "The parameter name is " << param_node->name() << ", the shape is " << shape;
 }
 
 // Clear param_shapes before training in auto-parallel or semi-auto-parallel mode
@@ -192,6 +206,10 @@ void ParallelContext::ParallelParameterContextCkptShape(const FuncGraphPtr &func
   MS_EXCEPTION_IF_NULL(func_graph);
   MS_EXCEPTION_IF_NULL(param_node);
   MS_EXCEPTION_IF_NULL(ptr);
+  if (!func_graph->has_flag(AUTO_PARALLEL)) {
+    return;
+  }
+
   if (!init_param_shape_) {
     return;
   }
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
index cece82022c..f5708b4df1 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/ops_utils.h
@@ -110,6 +110,8 @@
 constexpr char STRIDES[] = "strides";
 constexpr char GROUP[] = "group";
 constexpr char FUSION[] = "fusion";
 constexpr char DO_MIRROR[] = "do_mirror";
+constexpr char RECOMPUTE[] = "recompute";
+constexpr char RECOMPUTE_COMM_OP[] = "recompute_comm_op";
 constexpr char NUM_SAMPLED[] = "num_sampled";
 constexpr char NUM_TRUE[] = "num_true";
 constexpr char SEED[] = "seed";
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
index 99c90100a4..fe011ddc4b 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
@@ -97,6 +97,27 @@ void SetMiniStepOpDoMirrorLabel(std::vector<AnfNodePtr> new_node_input, bool acc
   prim->SetAttrs(attrs);
 }
 
+void SetAllReduceRecomputeFlag(const std::vector<AnfNodePtr> &new_node_input, const CNodePtr &node) {
+  if (new_node_input.empty()) {
+    return;
+  }
+
+  auto prim_anf_node = new_node_input[0]->cast<ValueNodePtr>();
+  auto prim = GetValueNode<PrimitivePtr>(prim_anf_node);
+  MS_EXCEPTION_IF_NULL(prim);
+  auto attrs = prim->attrs();
+
+  auto anf_node = node->input(0)->cast<ValueNodePtr>();
+  auto prim_node = GetValueNode<PrimitivePtr>(anf_node);
+  MS_EXCEPTION_IF_NULL(prim_node);
+  auto node_attrs = prim_node->attrs();
+  if (node_attrs.find(RECOMPUTE_COMM_OP) != node_attrs.end() && !GetValue<bool>(node_attrs[RECOMPUTE_COMM_OP])) {
+    attrs[RECOMPUTE] = MakeValue(false);
+    prim->SetAttrs(attrs);
+    MS_LOG(INFO) << "Do not recompute the forward communication operator of " << prim_node->ToString();
+  }
+}
+
 std::vector<AnfNodePtr> CreateInput(const Operator &op, const AnfNodePtr &node, const std::string &instance_name) {
   MS_EXCEPTION_IF_NULL(node);
   OperatorArgs arg_forward = op.second;
@@ -353,6 +374,7 @@ void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node) {
     std::string instance_name_base = FORWARD_OP;
     std::string instance_name = instance_name_base + "_" + CreateInstanceName(node, index);
     std::vector<AnfNodePtr> forward_input = CreateInput(forward_op[index], node_to_insert, instance_name);
+    SetAllReduceRecomputeFlag(forward_input, node_to_insert);
     CNodePtr forward_node = func_graph->NewCNode(forward_input);  // using NewCNode to create anfnode
     MS_EXCEPTION_IF_NULL(forward_node);
     ScopePtr scope = node->scope();
@@ -1165,7 +1187,14 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
 
     // not a RefKey
     if (!param_node_pair.second) {
-      auto next_cnode = FindCNode(param_node_pair.first, MIRROR_OPERATOR, func_graph);
+      int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
+      std::string mirror_op_name;
+      if (grad_accumulation_step > 1) {
+        mirror_op_name = MIRROR_MINI_STEP_OPERATOR;
+      } else {
+        mirror_op_name = MIRROR_OPERATOR;
+      }
+      auto next_cnode = FindCNode(param_node_pair.first, mirror_op_name, func_graph);
       // if there is already a MirrorOp in the same graph, use MirrorOp CNode as a input instead
       if (next_cnode.first) {
         MS_EXCEPTION_IF_NULL(next_cnode.second);
@@ -1743,6 +1772,10 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
       if (found_be_cloned_parameter) {
         // set the shape and tensor layout for cloned parameter
         std::string param_name = cloned_parameter_node->cast<ParameterPtr>()->name();
+        if (cloned_from_parameter->user_data<TensorLayout>() == nullptr) {
+          MS_LOG(WARNING) << "The parameter " << param_name << " has no tensor layout, skip it";
+          continue;
+        }
         cloned_parameter->set_user_data<TensorLayout>(cloned_from_parameter->user_data<TensorLayout>());
         MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract());
         MS_EXCEPTION_IF_NULL(cloned_from_node->abstract());
@@ -3298,6 +3331,97 @@ static void HandleNoUsedParameter(const FuncGraphPtr &root) {
   }
 }
 
+static bool IsFullySplitParameter(const ParameterPtr &param_ptr) {
+  auto tensor_layout = param_ptr->user_data<TensorLayout>();
+  if (tensor_layout == nullptr) {
+    return false;
+  }
+
+  auto dev_mat_shape = tensor_layout->device_arrangement().array();
+  auto tensor_map = tensor_layout->tensor_map().array();
+  int64_t rank = g_device_manager->global_rank();
+  RankList rank_list = g_device_manager->GetDeviceListInThisStage();
+  DeviceMatrix dev_matrix(rank, rank_list, dev_mat_shape);
+  RankList group_devices;
+  if (dev_matrix.GetDevicesByTensorMap(tensor_map, &group_devices) != SUCCESS) {
+    MS_LOG(WARNING) << "Get devices by tensor map failed, invalid tensor layout";
+    return false;
+  }
+
+  if (group_devices.size() == 1) {
+    MS_LOG(INFO) << "The parameter: " << param_ptr->name() << " is fully split";
+    return true;
+  }
+  return false;
+}
+
+static AnfNodePtr FindGradAccuParameter(const std::vector<AnfNodePtr> &parameters, const std::string &name) {
+  for (auto &parameter : parameters) {
+    auto param_ptr = parameter->cast<ParameterPtr>();
+    MS_EXCEPTION_IF_NULL(param_ptr);
+    if (param_ptr->name() == name) {
+      continue;
+    }
+    if (param_ptr->name().find(name) != std::string::npos && param_ptr->name().find("accu_grad") != std::string::npos) {
+      return parameter;
+    }
+  }
+  return nullptr;
+}
+
+static void InsertFullySplitParamGradAccu(const std::pair<AnfNodePtr, int> &node_user,
+                                          const FuncGraphManagerPtr &manager, const AnfNodePtr &accu_parameter) {
+  auto cnode = node_user.first->cast<CNodePtr>();
+  auto prim = GetCNodePrimitive(cnode);
+  if (prim == nullptr) {
+    MS_LOG(WARNING) << cnode->DebugString() << " cannot insert the fully split param grad accumulation node";
+    return;
+  }
+  OperatorAttrs attrs;
+  auto py_instance = CreatOpInstance(attrs, "_VirtualAdd", "grad_accu");
+  auto value_node = NewValueNode(py_instance);
+  std::vector<AnfNodePtr> virtual_node_input = {value_node, cnode->input(node_user.second), accu_parameter};
+  auto graph = cnode->func_graph();
+  auto virtual_node = graph->NewCNode(virtual_node_input);
+  manager->SetEdge(cnode, node_user.second, virtual_node);
+}
+
+static void HandleFullySplitParameters(const FuncGraphPtr &root) {
+  int64_t grad_accumulation_step = ParallelContext::GetInstance()->grad_accumulation_step();
+  if ((grad_accumulation_step <= 1) || root->has_flag(ACCUMULATION)) {
+    return;
+  }
+
+  auto parameters = root->parameters();
+  auto node_users_map = root->manager()->node_users();
+  for (auto &parameter : parameters) {
+    auto param_ptr = parameter->cast<ParameterPtr>();
+    MS_EXCEPTION_IF_NULL(param_ptr);
+
+    if (!IsFullySplitParameter(param_ptr)) {
+      continue;
+    }
+
+    auto accu_parameter = FindGradAccuParameter(parameters, param_ptr->name());
+    if (!accu_parameter) {
+      continue;  // some parameters need no handling, such as the parameter itself or the lr
+    }
+
+    auto node_users = node_users_map[parameter];
+    for (auto &user : node_users) {
+      auto node = user.first;
+      auto cnode = node->cast<CNodePtr>();
+      MS_EXCEPTION_IF_NULL(cnode);
+      if (!cnode->in_forward_flag()) {
+        continue;
+      }
+      InsertFullySplitParamGradAccu(user, root->manager(), accu_parameter);
+      MS_LOG(INFO) << "Insert fully split assign add node for " << param_ptr->name();
+      break;  // only need to insert once, if the parameter has many users
+    }
+  }
+}
+
 bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer) {
 #if (ENABLE_CPU && (ENABLE_D || ENABLE_GPU))
   if (ps::PSContext::instance()->is_server() || ps::PSContext::instance()->is_scheduler()) {
@@ -3390,6 +3514,9 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
     MS_LOG(EXCEPTION) << "Save group info failed";
   }
 
+  // handle fully split parameters in grad accumulation; parameters handled by optimizer sharding are not included
+  HandleFullySplitParameters(root);
+
   DumpGraph(root, std::string(STEP_PARALLEL_END));
 
   // step parallel only run once
diff --git a/mindspore/ccsrc/pipeline/jit/pass.cc b/mindspore/ccsrc/pipeline/jit/pass.cc
index a0013bb8c2..af0b2d6cf6 100644
--- a/mindspore/ccsrc/pipeline/jit/pass.cc
+++ b/mindspore/ccsrc/pipeline/jit/pass.cc
@@ -159,6 +159,7 @@ OptPassGroupMap GetOptPassesA(const opt::irpass::OptimizeIRPassLib &irpass) {
      irpass.switch_layer_defer_inline_,
      irpass.replace_applicator_,
      irpass.mirror_mini_step_elim_,
+     irpass.virtual_add_elim_,
      irpass.row_tensor_add_zeros_like_,
      irpass.mini_step_allgather_replace_,
    });
diff --git a/mindspore/core/base/core_ops.h b/mindspore/core/base/core_ops.h
index c180ba60c9..6f989e775e 100644
--- a/mindspore/core/base/core_ops.h
+++ b/mindspore/core/base/core_ops.h
@@ -307,6 +307,7 @@ inline const PrimitivePtr kPrimMirror = std::make_shared<Primitive>("_MirrorOper
 inline const PrimitivePtr kPrimMirrorMiniStep = std::make_shared<Primitive>("_MirrorMiniStepOperator");
 inline const PrimitivePtr kPrimMiniStepAllGather = std::make_shared<Primitive>("_MiniStepAllGather");
 inline const PrimitivePtr kPrimVirtualDiv = std::make_shared<Primitive>("_VirtualDiv");
+inline const PrimitivePtr kPrimVirtualAdd = std::make_shared<Primitive>("_VirtualAdd");
 inline const PrimitivePtr kPrimVirtualDataset = std::make_shared<Primitive>("_VirtualDataset");
 inline const PrimitivePtr kPrimSend = std::make_shared<Primitive>("Send");
 inline const PrimitivePtr kPrimReceive = std::make_shared<Primitive>("Receive");
diff --git a/mindspore/ops/_grad/grad_comm_ops.py b/mindspore/ops/_grad/grad_comm_ops.py
index 425764f8d6..d41cc58420 100644
--- a/mindspore/ops/_grad/grad_comm_ops.py
+++ b/mindspore/ops/_grad/grad_comm_ops.py
@@ -22,7 +22,7 @@
 from ...common.tensor import RowTensor
 from ..composite.multitype_ops.zeros_like_impl import zeros_like
 from ..operations.comm_ops import (AllGather, _MiniStepAllGather, _HostAllGather, AllReduce, _AlltoAll, Broadcast,
                                    _GetTensorSlice, _MirrorOperator, _MirrorMiniStepOperator, ReduceOp,
-                                   ReduceScatter, _HostReduceScatter, _VirtualDiv, AllSwap)
+                                   ReduceScatter, _HostReduceScatter, _VirtualDiv, _VirtualAdd, AllSwap)
 from .grad_base import bprop_getters
 from ..operations._inner_ops import Send, Receive
 
@@ -108,6 +108,14 @@ def get_bprop_receive(self):
     return bprop
 
 
+@bprop_getters.register(_VirtualAdd)
+def get_bprop_virtual_add(self):
+    """Generate bprop for _VirtualAdd."""
+    def bprop(x, grad_accu, out, dout):
+        return (dout + grad_accu, zeros_like(grad_accu))
+    return bprop
+
+
 @bprop_getters.register(Broadcast)
 def get_bprop_broad_cast(self):
     """Generate bprop for Broadcast."""
@@ -168,13 +176,13 @@ def get_bprop_mini_step_all_gather(self):
     def bprop(x, z, out, dout):
         if do_mirror:
             if mean_flag:
-                tmp = z + dout
-                grad = all_reduce(tmp)
+                z = F.depend(z, F.assign_add(z, dout))
+                grad = all_reduce(z)
                 dx = split(grad)[rank]
                 dx = F.tensor_mul(dx, scale)
             else:
-                tmp = z + dout
-                grad = all_reduce(tmp)
+                z = F.depend(z, F.assign_add(z, dout))
+                grad = all_reduce(z)
                 dx = split(grad)[rank]
         else:
             dx = dout
@@ -326,7 +334,6 @@ def get_bprop_mirror_mini_step_operator(self):
     mean_flag = self.mean_flag
 
     all_reduce = AllReduce(group=group)
-    all_gather = AllGather(group=group)
     mul = P.Mul()
     cast = P.Cast()
 
@@ -345,8 +352,8 @@ def get_bprop_mirror_mini_step_operator(self):
         if mean_flag:
             if F.issubclass_(F.typeof(dout), mstype.tensor):
                 if do_mirror:
-                    tmp = z + dout
-                    real_grad = all_reduce(tmp)
+                    z = F.depend(z, F.assign_add(z, dout))
+                    real_grad = all_reduce(z)
                     dx = real_grad
                 else:
                     dx = dout
@@ -354,32 +361,17 @@ def get_bprop_mirror_mini_step_operator(self):
                 num = F.scalar_cast(dev_num, F.dtype(dx))
                 dx = mul(dx, cast(F.scalar_to_array(float_one/num), F.dtype(dx)))
             else:
-                if do_mirror:
-                    indices = all_gather(dout.indices)
-                    grad = all_gather(dout.values)
-                else:
-                    indices = dout.indices
-                    grad = dout.values
-                float_one = F.scalar_cast(1.0, F.dtype(grad))
-                num = F.scalar_cast(dev_num, F.dtype(grad))
-                grad = mul(grad, cast(F.scalar_to_array(float_one/num), F.dtype(grad)))
-                dx = RowTensor(indices, grad, dout.dense_shape)
+                dx = zeros_like(x)  # grad accumulation does not support row tensor for now
         else:
             if F.issubclass_(F.typeof(dout), mstype.tensor):
                 if do_mirror:
-                    tmp = z + dout
-                    real_grad = all_reduce(tmp)
+                    z = F.depend(z, F.assign_add(z, dout))
+                    real_grad = all_reduce(z)
                     dx = real_grad
                 else:
                     dx = dout
             else:
-                if do_mirror:
-                    indices = all_gather(dout.indices)
-                    grad = all_gather(dout.values)
-                else:
-                    indices = dout.indices
-                    grad = dout.values
-                dx = RowTensor(indices, grad, dout.dense_shape)
+                dx = zeros_like(x)  # grad accumulation does not support row tensor for now
 
         return (dx, zeros_like(z))
     return bprop
diff --git a/mindspore/ops/operations/__init__.py b/mindspore/ops/operations/__init__.py
index 91a3a4e769..762adcfd03 100644
--- a/mindspore/ops/operations/__init__.py
+++ b/mindspore/ops/operations/__init__.py
@@ -36,7 +36,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Stack, Unpack, Unsta
                         Unique, GatherD, Identity, Range)
 from .comm_ops import (AllGather, AllReduce, _AlltoAll, AllSwap, ReduceScatter, Broadcast,
                        _MirrorOperator, _MirrorMiniStepOperator, _MiniStepAllGather, ReduceOp, _VirtualDataset,
-                       _VirtualDiv, _GetTensorSlice,
+                       _VirtualDiv, _GetTensorSlice, _VirtualAdd,
                        _HostAllGather, _HostReduceScatter)
 from .debug_ops import (ImageSummary, InsertGradientOf, HookBackward, ScalarSummary,
                         TensorSummary, HistogramSummary, Print, Assert)
diff --git a/mindspore/ops/operations/comm_ops.py b/mindspore/ops/operations/comm_ops.py
index 13ee156730..a5834ea66c 100644
--- a/mindspore/ops/operations/comm_ops.py
+++ b/mindspore/ops/operations/comm_ops.py
@@ -653,6 +653,19 @@ class _VirtualDiv(PrimitiveWithInfer):
 virtual_div = _VirtualDiv()
 
 
+class _VirtualAdd(PrimitiveWithInfer):
+    """Auto parallel virtual operator. Does nothing in the forward pass; performs Add in the backward pass."""
+    @prim_attr_register
+    def __init__(self):
+        """init"""
+
+    def infer_shape(self, x_shape, y_shape):
+        return x_shape
+
+    def infer_dtype(self, x_dtype, y_dtype):
+        return x_dtype
+
+
 class _VirtualDataset(PrimitiveWithInfer):
     """
     Auto parallel virtual dataset operator.
diff --git a/model_zoo/official/nlp/gpt/src/gpt.py b/model_zoo/official/nlp/gpt/src/gpt.py
index b5c5b7d900..80fec16060 100644
--- a/model_zoo/official/nlp/gpt/src/gpt.py
+++ b/model_zoo/official/nlp/gpt/src/gpt.py
@@ -25,6 +25,7 @@ from mindspore.common.initializer import TruncatedNormal, initializer, Normal
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 
+
 class LayerNorm(nn.Cell):
     """
     Layer Normalization
diff --git a/tests/ut/python/parallel/test_get_parameter_layout.py b/tests/ut/python/parallel/test_get_parameter_layout.py
index 991e0300c7..6c3100390a 100644
--- a/tests/ut/python/parallel/test_get_parameter_layout.py
+++ b/tests/ut/python/parallel/test_get_parameter_layout.py
@@ -47,6 +47,7 @@ def test_get_parameter_layout():
 
     net = Net(strategy1, strategy2, weight)
     net.set_auto_parallel()
+    net.set_train()
     exe = me._executor
     exe.compile(net, x, phase='train', auto_parallel_mode=True)
     x_layout = ([2, 4], [1, -1], [16, 32], 0, True, '')  # device_arrangement = [2, 4], tensor_map = [1, -1]
diff --git a/tests/ut/python/parallel/test_grad_accumulation.py b/tests/ut/python/parallel/test_grad_accumulation.py
deleted file mode 100644
index 3b1bb48b73..0000000000
--- a/tests/ut/python/parallel/test_grad_accumulation.py
+++ /dev/null
@@ -1,307 +0,0 @@
-# Copyright 2021 Huawei Technologies Co., Ltd
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ============================================================================
-import numpy as np
-
-import mindspore as ms
-import mindspore.common.dtype as mstype
-from mindspore import context, Tensor, Parameter
-from mindspore.train import Model
-from mindspore.ops import operations as P
-from mindspore.ops import composite as C
-from mindspore.ops import functional as F
-from mindspore.common.initializer import initializer
-from mindspore.context import ParallelMode
-from mindspore.nn import DistributedGradReducer, DynamicLossScaleUpdateCell, Cell, Momentum, Norm
-from mindspore.parallel._utils import _get_device_num
-from tests.dataset_mock import MindData
-
-
-class Dataset(MindData):
-    def __init__(self, predict, label, length=3):
-        super(Dataset, self).__init__(size=length)
-        self.predict = predict
-        self.label = label
-        self.index = 0
-        self.length = length
-
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        if self.index >= self.length:
-            raise StopIteration
-        self.index += 1
-        return self.predict, self.label
-
-    def reset(self):
-        self.index = 0
-
-
-get_square_sum = C.MultitypeFuncGraph("get_square_sum")
-@get_square_sum.register("Tensor")
-def _get_square_sum(grad):
-    norm = P.ReduceSum(False)(F.square(grad), ())
-    norm = F.expand_dims(F.cast(norm, mstype.float32), 0)
-    return norm
-
-
-apply_global_norm = C.MultitypeFuncGraph("apply_global_norm")
-@apply_global_norm.register("Tensor", "Tensor", "Tensor")
-def _apply_global_norm(clip_norm, global_norm, grad):
-    grad = grad * clip_norm / global_norm
-    return grad
-
-
-class GlobalNorm(Cell):
-    """
-    Calculate the global norm value of given tensors
-    """
-    def __init__(self):
-        super(GlobalNorm, self).__init__()
-        self.norm = Norm()
-        self.hyper_map = C.HyperMap()
-
-    def construct(self, grads):
-        square_sum = self.hyper_map(get_square_sum, grads)
-        global_norms = F.sqrt(F.addn(square_sum) / F.scalar_to_array(len(square_sum)))
-        return global_norms
-
-
-class ClipByGlobalNorm(Cell):
-    """
-    Clip grads by global norm
-    """
-    def __init__(self, clip_norm=1.0):
-        super(ClipByGlobalNorm, self).__init__()
-        self.global_norm = GlobalNorm()
-        self.clip_norm = Tensor([clip_norm], mstype.float32)
-        self.hyper_map = C.HyperMap()
-
-    def construct(self, grads):
-        global_norm = self.global_norm(grads)
-        cond = P.GreaterEqual()(global_norm, self.clip_norm)
-        global_norm = F.select(cond, global_norm, self.clip_norm)
-        grads = self.hyper_map(F.partial(apply_global_norm, self.clip_norm, global_norm), grads)
-        return grads
-
-
-cast = P.Cast()
-update_accu_grads = C.MultitypeFuncGraph("update_accu_grads")
-
-
-@update_accu_grads.register("Tensor", "Tensor")
-def _update_accu_grads(accu_grad, grad):
-    succ = True
-    return F.depend(succ, F.assign_add(accu_grad, cast(grad, mstype.float32)))
-
-
-zeroslike = P.ZerosLike()
-reset_accu_grads = C.MultitypeFuncGraph("reset_accu_grads")
-
-
-@reset_accu_grads.register("Tensor")
-def _reset_accu_grads(accu_grad):
-    succ = True
-    return F.depend(succ, F.assign(accu_grad, zeroslike(accu_grad)))
-
-
-grad_scale = C.MultitypeFuncGraph("grad_scale")
-reciprocal = P.Reciprocal()
-
-
-@grad_scale.register("Tensor", "Tensor")
-def tensor_grad_scale(scale, grad):
-    return grad * reciprocal(scale)
-
-
-class TrainAccumulateStepsWithLossScaleCell(Cell):
-    """
-    Encapsulation class of bert network training.
-
-    Append an optimizer to the training network after that the construct
-    function can be called to create the backward graph. To mimic higher batch size, gradients are
-    accumulated N times before weight update.
-
-    Args:
-        network (Cell): The training network. Note that loss function should have been added.
-        optimizer (Optimizer): Optimizer for updating the weights.
-        scale_update_cell (Cell): Cell to do the loss scale. Default: None.
-        accumulation_steps (int): Number of accumulation steps before gradient update. The global batch size =
-                                  batch_size * accumulation_steps. Default: 1.
-    """
-    def __init__(self, network, optimizer, scale_update_cell=None):
-        super(TrainAccumulateStepsWithLossScaleCell, self).__init__(auto_prefix=False)
-        self.accu = False
-        self.is_accu_step = Tensor(np.array([self.accu]))
-        self.network = network
-        self.network.set_grad()
-        self.weights = optimizer.parameters
-        self.optimizer = optimizer
-        self.accumulation_steps = context.get_auto_parallel_context("grad_accumulation_step")
-        self.one = Tensor(np.array([1]).astype(np.int32))
-        self.zero = Tensor(np.array([0]).astype(np.int32))
-        self.accu_grads = self.weights.clone(prefix="accu_grads", init='zeros')
-        self.accu_overflow = Parameter(initializer(0, [1], mstype.int32))
-        self.accu_loss = Parameter(initializer(0, [1], mstype.float32))
-        self.reducer_flag = False
-        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
-        self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
-        if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
-            self.reducer_flag = True
-        self.degree = 1
-        self.grad_reducer = F.identity
-        if self.reducer_flag:
-            self.degree = _get_device_num()
-            self.grad_reducer = DistributedGradReducer(optimizer.parameters, False, self.degree)
-        self.is_distributed = (self.parallel_mode != ParallelMode.STAND_ALONE)
-        self.overflow_reducer = F.identity
-        if self.is_distributed:
-            self.overflow_reducer = P.AllReduce()
-        self.cast = P.Cast()
-        self.alloc_status = P.NPUAllocFloatStatus()
-        self.get_status = P.NPUGetFloatStatus()
-        self.clear_before_grad = P.NPUClearFloatStatus()
-        self.reduce_sum = P.ReduceSum(keep_dims=False)
-        self.base = Tensor(1, mstype.float32)
-        self.less_equal = P.LessEqual()
-        self.logical_or = P.LogicalOr()
-        self.not_equal = P.NotEqual()
-        self.select = P.Select()
-        self.reshape = P.Reshape()
-        self.hyper_map = C.HyperMap()
-        self.loss_scale = None
-        self.loss_scaling_manager = scale_update_cell
-        if scale_update_cell:
-            self.loss_scale = Parameter(Tensor(scale_update_cell.get_loss_scale(), dtype=mstype.float32))
-
-    @C.add_flags(has_effect=True)
-    def construct(self, x, b, sens=None):
-        """Defines the computation performed."""
-        weights = self.weights
-        loss = self.network(x, b)
-        if sens is None:
-            scaling_sens = self.loss_scale
-        else:
-            scaling_sens = sens
-
-        # alloc status and clear should be right before gradoperation
-        init = self.alloc_status()
-        self.clear_before_grad(init)
-        grads = self.grad(self.network, weights)(x, b, self.cast(scaling_sens, mstype.float32))
-
-        if self.is_accu_step and self.accumulation_steps > 1:
-            accu_succ = self.hyper_map(update_accu_grads, self.accu_grads, grads)
-            loss = F.depend(loss, accu_succ)
-
-        self.get_status(init)
-        flag_sum = self.reduce_sum(init, (0,))
-        overflow = self.less_equal(self.base, flag_sum)
-        overflow = self.logical_or(self.not_equal(self.accu_overflow, self.zero), overflow)
-        accu_overflow = self.select(overflow, self.one, self.zero)
-        self.accu_overflow = self.select(self.is_accu_step, accu_overflow, self.zero)
-
-        if self.is_accu_step:
-            succ = False
-        else:
-            # apply grad reducer on grads
-            grads = self.grad_reducer(grads)
-            scaling = scaling_sens * self.degree * self.accumulation_steps
-            grads = self.hyper_map(F.partial(grad_scale, scaling), grads)
-            grads = ClipByGlobalNorm()(grads)
-            accu_overflow = self.overflow_reducer(accu_overflow)
-            F.control_depend(grads, accu_overflow)
-            overflow = self.less_equal(self.base, accu_overflow)
-            accu_succ = self.hyper_map(reset_accu_grads, self.accu_grads)
-            overflow = F.depend(overflow, accu_succ)
-            overflow = self.reshape(overflow, (()))
-            if sens is None:
-                overflow = self.loss_scaling_manager(self.loss_scale, overflow)
-            if overflow:
-                succ = False
-            else:
-                succ = self.optimizer(grads)
-
-        ret = (loss, overflow, scaling_sens)
-        return F.depend(ret, succ)
-
-
-class Net(Cell):
-    def __init__(self, weight, strategy=None):
-        super().__init__()
-        self.mul = P.Mul().shard(strategy)
-        self.weight = Parameter(weight, "w1")
-        self.relu = P.ReLU()
-        self.reduce_sum = P.ReduceSum(keep_dims=True)
-
-    def construct(self, x, b):
-        out = self.mul(x, self.weight)
-        out = self.relu(out)
-        out = self.reduce_sum(out)
-        return out
-
-
-_x = Tensor(np.ones([2]), dtype=ms.float32)
-_b = Tensor(np.ones([16]), dtype=ms.float32)
-_w1 = Tensor(np.ones([16]), dtype=ms.float32)
-
-
-def compile_net(net):
-    context.set_context(enable_sparse=False)
-    learning_rate = 0.1
-    momentum = 0.9
-    epoch_size = 2
-    dataset = Dataset(_x, _b)
-    opt = Momentum(net.trainable_params(), learning_rate, momentum)
-    update_cell = DynamicLossScaleUpdateCell(loss_scale_value=65536, scale_factor=2, scale_window=1000)
-    net_wrap = TrainAccumulateStepsWithLossScaleCell(net, opt, scale_update_cell=update_cell)
-    model = Model(net_wrap)
-    model.train(epoch_size, dataset, dataset_sink_mode=False)
-    context.reset_auto_parallel_context()
-
-
-def test_grad_accumulation_accu():
-    grad_accumulation_step = 4
-    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0,
-                                      grad_accumulation_step=grad_accumulation_step)
-    strategy = ((2,), (2,))
-    net = Net(_w1, strategy).add_flags_recursive(accu=True)
-    compile_net(net)
-
-
-def test_grad_accu_and_opt_shard_accu():
-    grad_accumulation_step = 4
-    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0,
-                                      grad_accumulation_step=grad_accumulation_step, enable_parallel_optimizer=True)
-    strategy = ((2,), (2,))
-    net = Net(_w1, strategy).add_flags_recursive(accu=True)
-    compile_net(net)
-
-
-def test_grad_accumulation_not_accu():
-    grad_accumulation_step = 4
-    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0,
-                                      grad_accumulation_step=grad_accumulation_step)
-    strategy = ((2,), (2,))
-    net = Net(_w1, strategy).add_flags_recursive(accu=False)
-    compile_net(net)
-
-
-def test_grad_accu_and_opt_shard_not_accu():
-    grad_accumulation_step = 4
-    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8, global_rank=0,
-                                      grad_accumulation_step=grad_accumulation_step, enable_parallel_optimizer=True)
-    strategy = ((2,), (2,))
-    net = Net(_w1, strategy).add_flags_recursive(accu=False)
-    compile_net(net)
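
Illustration (supplementary, not part of the patch): the new _VirtualAdd primitive is an identity in the forward pass, and its bprop, get_bprop_virtual_add above, returns (dout + grad_accu, zeros_like(grad_accu)) -- it folds the gradient accumulated over previous micro-steps into the incoming gradient of a fully split parameter while blocking gradient flow into the accumulation buffer. A minimal plain-NumPy sketch of that contract; the function names here are illustrative only:

    import numpy as np

    def virtual_add_forward(x, grad_accu):
        # forward: identity on x; grad_accu only rides along for the backward pass
        return x

    def virtual_add_bprop(x, grad_accu, out, dout):
        # backward: fold the accumulated gradient into the incoming gradient and
        # return a zero gradient for the accumulation buffer itself
        return dout + grad_accu, np.zeros_like(grad_accu)

    x = np.ones(4, dtype=np.float32)
    accu = np.full(4, 0.5, dtype=np.float32)   # gradient accumulated so far
    out = virtual_add_forward(x, accu)
    dx, daccu = virtual_add_bprop(x, accu, out, np.ones_like(x))
    print(dx)     # [1.5 1.5 1.5 1.5] -> dout plus the accumulated gradient
    print(daccu)  # [0. 0. 0. 0.]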
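
Usage sketch (hedged; modeled on the deleted test_grad_accumulation.py, which exercised exactly this path): gradient accumulation in semi-auto parallel is switched on through the grad_accumulation_step value of the auto-parallel context, and the accumulation graph is distinguished from the update graph by the accu flag. MulNet below is a hypothetical toy cell; actually running it additionally needs a distributed backend and a training wrapper such as the deleted TrainAccumulateStepsWithLossScaleCell:

    import numpy as np
    import mindspore as ms
    from mindspore import context, Tensor, Parameter
    from mindspore.nn import Cell
    from mindspore.ops import operations as P

    class MulNet(Cell):
        """Toy cell: an elementwise Mul sharded across devices."""
        def __init__(self, weight, strategy=None):
            super().__init__()
            self.mul = P.Mul().shard(strategy)
            self.weight = Parameter(weight, "w1")

        def construct(self, x):
            return self.mul(x, self.weight)

    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel", device_num=8,
                                      global_rank=0, grad_accumulation_step=4)
    net = MulNet(Tensor(np.ones([16]), dtype=ms.float32), strategy=((2,), (2,)))
    # compile the accumulation graph with accu=True and the update graph with accu=False
    accu_net = net.add_flags_recursive(accu=True)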