From 7303c3d3b8de4d44bb004eab2c974ea97a91129b Mon Sep 17 00:00:00 2001
From: yangzhenzhang
Date: Sat, 9 Jan 2021 15:32:15 +0800
Subject: [PATCH] add group ckpt

---
 mindspore/ccsrc/frontend/parallel/context.cc | 4 ++
 mindspore/ccsrc/frontend/parallel/context.h | 3 +
 .../ccsrc/frontend/parallel/device_manager.h | 1 +
 .../ccsrc/frontend/parallel/dynamic_creator.h | 10 ++--
 .../ccsrc/frontend/parallel/group_manager.cc | 26 +++++++++
 .../ccsrc/frontend/parallel/group_manager.h | 5 ++
 .../parallel/ops_info/reduce_method_info.cc | 8 +--
 .../ccsrc/frontend/parallel/step_parallel.cc | 48 +++++++++++-----
 .../ccsrc/frontend/parallel/step_parallel.h | 4 +-
 .../parallel_strategy_checkpoint.cc | 57 +++++++++++++++++++
 .../parallel_strategy_checkpoint.h | 8 +++
 mindspore/ccsrc/pipeline/jit/init.cc | 1 +
 mindspore/ccsrc/utils/node_strategy.proto | 13 +++++
 mindspore/parallel/_auto_parallel_context.py | 13 ++++-
 .../parallel_strategy_checkpoint_stub.cc | 4 ++
 .../parallel/test_strategy_checkpoint.py | 6 +-
 16 files changed, 184 insertions(+), 27 deletions(-)

diff --git a/mindspore/ccsrc/frontend/parallel/context.cc b/mindspore/ccsrc/frontend/parallel/context.cc
index d24e10b0e9..c03c64ce30 100644
--- a/mindspore/ccsrc/frontend/parallel/context.cc
+++ b/mindspore/ccsrc/frontend/parallel/context.cc
@@ -124,6 +124,10 @@ void ParallelContext::set_strategy_ckpt_save_file(const std::string &strategy_ck
   strategy_ckpt_save_file_ = strategy_ckpt_save_file;
 }
 
+void ParallelContext::set_group_ckpt_save_file(const std::string &group_ckpt_save_file) {
+  group_ckpt_save_file_ = group_ckpt_save_file;
+}
+
 void ParallelContext::SetAllReduceFusionSplitIndices(const std::vector indices, const std::string &group) {
   all_reduce_fusion_split_indices_[group] = indices;
 }
diff --git a/mindspore/ccsrc/frontend/parallel/context.h b/mindspore/ccsrc/frontend/parallel/context.h
index 4f964bc479..d4212dde42 100644
--- a/mindspore/ccsrc/frontend/parallel/context.h
+++ b/mindspore/ccsrc/frontend/parallel/context.h
@@ -102,6 +102,8 @@ class ParallelContext {
   std::string strategy_ckpt_load_file() const { return strategy_ckpt_load_file_; }
   void set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file);
   std::string strategy_ckpt_save_file() const { return strategy_ckpt_save_file_; }
+  void set_group_ckpt_save_file(const std::string &group_ckpt_save_file);
+  std::string group_ckpt_save_file() const { return group_ckpt_save_file_; }
 
   void set_enable_parallel_optimizer(bool enable_parallel_optimizer) {
     enable_parallel_optimizer_ = enable_parallel_optimizer;
@@ -132,6 +134,7 @@ class ParallelContext {
   std::map> all_reduce_fusion_split_sizes_;
   std::string strategy_ckpt_load_file_;
   std::string strategy_ckpt_save_file_;
+  std::string group_ckpt_save_file_;
   bool enable_parallel_optimizer_;
 };
diff --git a/mindspore/ccsrc/frontend/parallel/device_manager.h b/mindspore/ccsrc/frontend/parallel/device_manager.h
index 60432484bc..1d3d711d38 100644
--- a/mindspore/ccsrc/frontend/parallel/device_manager.h
+++ b/mindspore/ccsrc/frontend/parallel/device_manager.h
@@ -83,6 +83,7 @@ class DeviceManager {
   void Clear();
 
   std::string world_group() const { return gm_.world_group(); }
+  std::vector>> group_info() const { return gm_.group_info(); }
   std::string FindRankListNameByHashName(const std::string &hash_name);
 
  private:
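The context plumbing above mirrors the existing strategy-checkpoint settings: the new group_ckpt_save_file string is stored on ParallelContext, and DeviceManager now exposes the (group name, rank list) pairs recorded by its GroupManager. A minimal sketch of how the new accessors are meant to be used together, assuming the ParallelContext singleton accessor used later in this patch; the path is a placeholder:

    // Sketch only: setting the new option by hand instead of going through the
    // Python context API bound in pipeline/jit/init.cc below.
    #include "frontend/parallel/context.h"

    void EnableGroupCkpt() {
      auto parallel_context = mindspore::parallel::ParallelContext::GetInstance();
      parallel_context->set_group_ckpt_save_file("./group_info.ckpt");  // placeholder path
      // A non-empty path turns on StrategyCheckpoint::group_info_save_on(), and
      // StepParallel then dumps g_device_manager->group_info() at the end of the pass.
      bool enabled = !parallel_context->group_ckpt_save_file().empty();
      (void)enabled;
    }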
diff --git a/mindspore/ccsrc/frontend/parallel/dynamic_creator.h b/mindspore/ccsrc/frontend/parallel/dynamic_creator.h
index acf8ecdbd8..43108ea6e1 100644
--- a/mindspore/ccsrc/frontend/parallel/dynamic_creator.h
+++ b/mindspore/ccsrc/frontend/parallel/dynamic_creator.h
@@ -40,16 +40,16 @@ class DynCreator {
  public:
   ~DynCreator() = default;
 
-  // creat static singleton dyn_creator instance
+  // create static singleton dyn_creator instance
   static DynCreator &Instance() {
     static DynCreator fac = DynCreator();
     return fac;
   }
   // register
-  void Regist(std::string name, CreatFn func) { (void)Function_map_.insert(std::make_pair(name, func)); }
+  void Register(std::string name, CreatFn func) { (void)Function_map_.insert(std::make_pair(name, func)); }
   // creator
-  OperatorInfoPtr Creat(const std::string &name, const Shapes &shape_in, const Shapes &shape_out,
-                        const PrimitiveAttrs &attrs, size_t count) {
+  OperatorInfoPtr Create(const std::string &name, const Shapes &shape_in, const Shapes &shape_out,
+                         const PrimitiveAttrs &attrs, size_t count) {
     std::string op_name = name + std::to_string(count);
     auto iter = Function_map_.find(name);
     if (iter == Function_map_.end()) {
@@ -67,7 +67,7 @@ class DynCreator {
 class RegisterAction {
  public:
   RegisterAction(const std::string &name, CreatFn creatfn) : name_(name) {
-    DynCreator::Instance().Regist(name, creatfn);
+    DynCreator::Instance().Register(name, creatfn);
   }
   ~RegisterAction() = default;
diff --git a/mindspore/ccsrc/frontend/parallel/group_manager.cc b/mindspore/ccsrc/frontend/parallel/group_manager.cc
index a57cb3b72e..58eb60b284 100644
--- a/mindspore/ccsrc/frontend/parallel/group_manager.cc
+++ b/mindspore/ccsrc/frontend/parallel/group_manager.cc
@@ -17,6 +17,7 @@
 #include "frontend/parallel/group_manager.h"
 #include 
 #include 
+#include 
 #include "backend/session/executor_manager.h"
 #include "frontend/parallel/device_manager.h"
 #include "utils/comm_manager.h"
@@ -109,6 +110,9 @@ Status GroupManager::CreateGroup(const std::string &group_name, const std::vecto
     return Status::FAILED;
   }
 
+  std::pair> group_info = std::make_pair(group_name, ranks);
+  group_info_.push_back(group_info);
+
   MS_LOG(INFO) << "Create group success, group name is " << group_name;
   return Status::SUCCESS;
 }
@@ -187,5 +191,27 @@ Status GroupManager::FindGroup(const std::string &name, mindspore::parallel::Gro
 }
 
 void GroupManager::Clear() { (void)DestroyAllGroups(); }
+
+Status CreateGroups(const std::vector>> &group_info) {
+  // Create group through the executor
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  std::string device_name = context_ptr->get_param(MS_CTX_DEVICE_TARGET);
+  uint32_t device_id = context_ptr->get_param(MS_CTX_DEVICE_ID);
+  auto executor = session::ExecutorManager::Instance().GetExecutor(device_name, device_id);
+  MS_EXCEPTION_IF_NULL(executor);
+
+  for (auto &group : group_info) {
+    bool ret = executor->CreateCommGroup(group.first, group.second);
+    if (!ret) {
+      MS_LOG(ERROR) << "Create group failed, group name is " << group.first << ", ranks is " << group.second;
+      return FAILED;
+    }
+    MS_LOG(INFO) << "Create group success, group name is " << group.first << ", ranks is " << group.second;
+  }
+
+  return SUCCESS;
+}
+
 }  // namespace parallel
 }  // namespace mindspore
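GroupManager now records every communication group it creates, and the new free function CreateGroups replays such a list through the session executor. A minimal usage sketch follows; the template arguments are not visible in this patch excerpt, so the element type is assumed to be a (group name, rank list) pair with uint32_t ranks, and the group names and ranks shown are hypothetical:

    #include <cstdint>
    #include <string>
    #include <utility>
    #include <vector>
    #include "frontend/parallel/group_manager.h"

    // Assumed element type: CreateCommGroup takes a group name plus a list of rank ids.
    using GroupList = std::vector<std::pair<std::string, std::vector<uint32_t>>>;

    void ReplayGroups() {
      GroupList groups = {{"group_a", {0, 1, 2, 3}},  // hypothetical names and ranks
                          {"group_b", {0, 4}}};
      // Re-issues CreateCommGroup for each entry; stops at the first failure.
      if (mindspore::parallel::CreateGroups(groups) != mindspore::parallel::SUCCESS) {
        // handle the error, e.g. log and abort initialization
      }
    }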
diff --git a/mindspore/ccsrc/frontend/parallel/group_manager.h b/mindspore/ccsrc/frontend/parallel/group_manager.h
index eeb99a6751..3c106e8624 100644
--- a/mindspore/ccsrc/frontend/parallel/group_manager.h
+++ b/mindspore/ccsrc/frontend/parallel/group_manager.h
@@ -21,6 +21,7 @@
 #include 
 #include 
 #include 
+#include 
 #include "frontend/parallel/device.h"
 #include "frontend/parallel/status.h"
@@ -62,6 +63,7 @@ class GroupManager {
   Status FindGroup(const std::string &name, Group **group);
   std::string world_group() const { return world_group_; }
   void set_world_group(const std::string &name) { world_group_ = name; }
+  std::vector>> group_info() const { return group_info_; }
   void Clear();
 
  private:
@@ -69,7 +71,10 @@
   // the key is group name (name_)
   std::map groups_;
   std::string world_group_;
+  std::vector>> group_info_;
 };
+
+Status CreateGroups(const std::vector>> &group_info);
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc b/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc
index db8b81643e..7e46af83ba 100644
--- a/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc
+++ b/mindspore/ccsrc/frontend/parallel/ops_info/reduce_method_info.cc
@@ -160,7 +160,7 @@ Status ReduceMethod::InferForwardCommunication() {
   Shape group_creat_map;
 
   // if repeated calculation and the repeated_calc_num_ insert to the first dimension of dev matrix,
-  // it need to handle the first dimention of map.
+  // it need to handle the first dimension of map.
   if ((dev_matrix_shape_.size() > size) && !repeated_num_in_dev_matrix_right_) {
     group_creat_map.push_back(SizeToInt(dev_matrix_shape_.size() - size_t(1)));
   }
@@ -200,12 +200,12 @@ Status ReduceMethod::InferForwardCommunication() {
 }
 
 ForwardOp CreateReduceMeanForwardOp(const std::vector &forward_group, const TypePtr &dtype) {
-  // Creat AllReduceSum op
+  // Create AllReduceSum op
   Operator op0 = CreateAllReduceOp(REDUCE_OP_SUM, forward_group[0].name());
   std::string group_name = forward_group[0].name();
   MS_LOG(INFO) << "The group of forward all reduce is " << group_name;
 
-  // Creat RealDiv op
+  // Create RealDiv op
   OperatorName operator1_name = REAL_DIV;
   std::vector device_list = forward_group[0].GetDevicesList();
   auto divisor = static_cast(device_list.size());
@@ -237,7 +237,7 @@ Status ReduceMeanInfo::InferForwardCommunication() {
   Shape group_creat_map;
 
   // if repeated calculation and the repeated_calc_num_ insert to the first dimension of dev matrix,
-  // it need to handle the first dimention of map.
+  // it need to handle the first dimension of map.
   if ((dev_matrix_shape_.size() > size) && !repeated_num_in_dev_matrix_right_) {
     group_creat_map.push_back(SizeToInt(dev_matrix_shape_.size() - size_t(1)));
   }
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.cc b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
index 910e9f9725..436d651394 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.cc
@@ -326,7 +326,7 @@ void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node) {
     std::string instance_name_base = FORWARD_OP;
     std::string instance_name = instance_name_base + "_" + CreateInstanceName(node, index);
     std::vector forward_input = CreateInput(forward_op[index], node_to_insert, instance_name);
-    CNodePtr forward_node = func_graph->NewCNode(forward_input);  // using NewCNode to creat anfnode
+    CNodePtr forward_node = func_graph->NewCNode(forward_input);  // using NewCNode to create anfnode
     MS_EXCEPTION_IF_NULL(forward_node);
     ScopePtr scope = node->scope();
     MS_EXCEPTION_IF_NULL(scope);
@@ -371,10 +371,10 @@ void InsertRedistribution(const RedistributionOpListPtr &redistribution_oplist_p
   if (pos >= SizeToLong(node->inputs().size())) {
     MS_LOG(EXCEPTION) << "InsertRedistribution:pos can't be larger than node's inputs'size";
   }
-  // Creat new node
+  // Create new node
   AnfNodePtr target_node = node->input(LongToSize(pos));
   MS_EXCEPTION_IF_NULL(target_node);
-  // Creat instance_name
+  // Create instance_name
   auto op = (redistribution_oplist_ptr->first)[index];
   std::string op_name = (redistribution_oplist_ptr->first)[index].first;
   std::string instance_name_base = REDISTRIBUTION_OP;
@@ -400,7 +400,7 @@ void InsertGetTensorSliceOp(const Operator &op, const CNodePtr &node, const Func
     MS_LOG(EXCEPTION) << "InsertGetTensorSliceOp: pos can't be larger than node's inputs'size, the instance name is "
                       << instance_name;
   }
-  // Creat new node
+  // Create new node
   AnfNodePtr pre_node = node->input(LongToSize(pos));
   MS_EXCEPTION_IF_NULL(pre_node);
   InsertNode(op, node, LongToSize(pos), pre_node, func_graph, instance_name);
@@ -595,7 +595,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_
   CNodePtr insert_node_new;
 
   if (AnfNodeIsPrimitive(node, MAKE_TUPLE) || AnfNodeIsPrimitive(node, MAKE_LIST)) {
-    MS_LOG(INFO) << "No need to insert redistribution op betweend make_tuple node and the next node";
+    MS_LOG(INFO) << "No need to insert redistribution op between make_tuple node and the next node";
     return;
   }
   if (IsValueNode(node->input(0))) {
@@ -883,10 +883,10 @@ void StepReplaceGraph(const ReplaceGraphPtr &replace_graph, const CNodePtr &node
   if (manager == nullptr) {
     MS_LOG(EXCEPTION) << "Failure:AddNode error since manager is nullptr";
   }
-  // Sovle the input order
+  // Solve the input order
   // For example input_node:{segment_sum:1, segment_sum:2, gahter:2}
-  // The Original code here will bind the all operations to the first inputs of theses operatos
-  // However, the segment_sum operation needs two inputs, To sovle this
+  // The original code here will bind all the operations to the first inputs of these operators
+  // However, the segment_sum operation needs two inputs, To solve this
   // We maintain a dict to count the times of the same operations,
   // and bind the inputs according to the times of the op appears.
   static std::unordered_map input_map = {};
@@ -1241,9 +1241,9 @@ OperatorInfoPtr OperatorInstanceByName(const std::string &name, const PrimitiveA
     }
   }
   OperatorInfoPtr operator_ =
-    (OperatorInfoPtr)DynCreator::Instance().Creat(distribute_opname, shape_list[0], shape_list[1], attrs, TOTAL_OPS);
+    (OperatorInfoPtr)DynCreator::Instance().Create(distribute_opname, shape_list[0], shape_list[1], attrs, TOTAL_OPS);
   if (operator_ == nullptr) {
-    MS_LOG(INFO) << "Creat " << name << " failed";
+    MS_LOG(INFO) << "Create " << name << " failed";
     return nullptr;
   }
   std::string origin_name = operator_->name();
@@ -1261,7 +1261,7 @@ OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs
     if (IsInBatchParallelBlackList(prim)) {
       MS_LOG(EXCEPTION) << "Operator " << prim->name() << " is not supported yet in auto parallel mode.";
     }
-    MS_LOG(INFO) << "Creat " << prim->name() << " failed, use batch parallel";
+    MS_LOG(INFO) << "Create " << prim->name() << " failed, use batch parallel";
     operator_ = OperatorInstanceByName(BATCH_PARALLEL, attrs, shape_list);
     MS_EXCEPTION_IF_NULL(operator_);
   }
@@ -1351,7 +1351,7 @@ Shapes GetNodeShape(const AnfNodePtr &node) {
   }
   if (cnode->input(0)->isa()) {
     if (cnode->inputs().size() < 2) {
-      MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " size is samller than 2";
+      MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " size is smaller than 2";
     }
     base_shape_ptr = cnode->input(1)->Shape();
   }
@@ -2546,7 +2546,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector &all_nodes) {
   auto param_split_shapes = gatherv2_info->param_split_shapes();
   auto index_offsets = gatherv2_info->index_offsets();
   if (param_split_shapes.size() != index_offsets.size()) {
-    MS_LOG(EXCEPTION) << "In manual split, the param_split_shapes and index_offsets lenght should be same.";
+    MS_LOG(EXCEPTION) << "In manual split, the param_split_shapes and index_offsets length should be same.";
   }
   std::vector> manual_shape;
   for (int64_t i = 0; i < UlongToLong(param_split_shapes.size()); ++i) {
@@ -2713,6 +2713,7 @@ void CheckpointStrategy(const std::vector &all_nodes) {
       }
     }
   }
+
   if (StrategyCheckpoint::GetInstance().Save(stra_map, tensor_info_map, &manual_shape_map) != SUCCESS) {
     MS_LOG(EXCEPTION) << "Save strategy checkpoint failed";
   }
@@ -3142,6 +3143,19 @@ void CheckParameterSplit(const std::vector &all_nodes) {
   }
 }
 
+bool CreateGroupsByCkptFile(const std::string &file) {
+  GroupInfoMap group_info_map;
+  if (StrategyCheckpoint::GetInstance().LoadGroupInfo(file, &group_info_map) != SUCCESS) {
+    return false;
+  }
+
+  if (CreateGroups(group_info_map) != SUCCESS) {
+    return false;
+  }
+  MS_LOG(INFO) << "Create groups by checkpoint file success";
+  return true;
+}
+
 bool IsUsedParameter(const FuncGraphPtr &graph, const AnfNodePtr &parameter) {
   MS_EXCEPTION_IF_NULL(graph);
   MS_EXCEPTION_IF_NULL(parameter);
@@ -3290,6 +3304,12 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
   // ForwardCommunication BackwardCommunication TensorRedistribution
   ParallelCommunication(root, all_nodes, manager);
 
+  auto group_info = g_device_manager->group_info();
+  if (StrategyCheckpoint::GetInstance().group_info_save_on() &&
+      StrategyCheckpoint::GetInstance().SaveGroupInfo(group_info) != SUCCESS) {
+    MS_LOG(EXCEPTION) << "Save group info failed";
+  }
+
   DumpGraph(root, std::string(STEP_PARALLEL_END));
 
   // step parallel only run once
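The step_parallel.cc changes add both directions of the feature: at the end of StepParallel the (group name, rank list) pairs accumulated by g_device_manager are written out whenever group_info_save_on() is set, and CreateGroupsByCkptFile reads such a file back and recreates the groups, so a later process can rebuild the same communication groups without rerunning the whole parallel pass. A minimal sketch of the restore side, with a placeholder path:

    #include <string>
    #include "frontend/parallel/step_parallel.h"

    // Sketch only: replays the groups recorded during the earlier compilation.
    // Assumes the file was produced by a run configured with group_ckpt_save_file.
    void RestoreCommGroups() {
      const std::string group_ckpt = "./group_info.ckpt";  // placeholder path
      if (!mindspore::parallel::CreateGroupsByCkptFile(group_ckpt)) {
        // the file was missing, unreadable, or a group could not be recreated
      }
    }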
diff --git a/mindspore/ccsrc/frontend/parallel/step_parallel.h b/mindspore/ccsrc/frontend/parallel/step_parallel.h
index 33a22a3b77..2926bba2dd 100644
--- a/mindspore/ccsrc/frontend/parallel/step_parallel.h
+++ b/mindspore/ccsrc/frontend/parallel/step_parallel.h
@@ -109,7 +109,7 @@ void CoverSliceShape(const FuncGraphPtr &root);
 
 void SetVirtualDatasetStrategy(const CNodePtr &node);
 
-// Creat parallel operator for primitive node(has strategy)
+// Create parallel operator for primitive node(has strategy)
 void ExtractInformation(const std::vector &all_nodes, bool is_training = true);
 
 TensorLayout GetInputLayoutFromCNode(const std::pair &node_pair);
@@ -163,6 +163,8 @@ void ApplyParallelOptOnParam(TensorLayout *tensor_layout, const OperatorInfoPtr
 
 void SetLastNodeStrategy(const StrategyPtr strategyPtr);
 
+bool CreateGroupsByCkptFile(const std::string &file);
+
 void FindLastNodesUniqueId(const std::vector &all_nodes, std::vector *unique_ids);
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc b/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc
index 5260f68f37..64e8614a0b 100644
--- a/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc
+++ b/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.cc
@@ -34,6 +34,8 @@ StrategyCheckpoint &StrategyCheckpoint::GetInstance() {
     instance.load_checkpoint_on_ = !ParallelContext::GetInstance()->strategy_ckpt_load_file().empty();
     instance.save_file_ = ParallelContext::GetInstance()->strategy_ckpt_save_file();
     instance.save_checkpoint_on_ = !ParallelContext::GetInstance()->strategy_ckpt_save_file().empty();
+    instance.group_info_save_file_ = ParallelContext::GetInstance()->group_ckpt_save_file();
+    instance.group_info_save_on_ = !ParallelContext::GetInstance()->group_ckpt_save_file().empty();
   }
   return instance;
 }
@@ -46,6 +48,39 @@ bool StrategyCheckpoint::CheckPointExit(const std::string path) const {
   return false;
 }
 
+Status StrategyCheckpoint::LoadGroupInfo(const std::string &file, GroupInfoMap *group_info_map) {
+  MS_EXCEPTION_IF_NULL(group_info_map);
+  if (!CheckPointExit(file)) {
+    MS_LOG(EXCEPTION) << "CheckPoint file is not found";
+  }
+  straspb::ParallelGroupMap parallel_group_map;
+  std::fstream input(file, std::ios::in | std::ios::binary);
+  if (!parallel_group_map.ParseFromIstream(&input)) {
+    MS_LOG(ERROR) << "Load strategy file failed";
+    return FAILED;
+  }
+  input.close();
+
+  size_t group_num = LongToSize(parallel_group_map.parallel_group_item_size());
+  for (size_t i = 0; i < group_num; ++i) {
+    straspb::ParallelGroupItem parallel_group_item = parallel_group_map.parallel_group_item(SizeToLong(i));
+    std::string group_name = parallel_group_item.group_name();
+
+    straspb::ParallelGroupRanks parallel_group_ranks = parallel_group_item.parallel_group_ranks();
+    size_t rank_num = LongToSize(parallel_group_ranks.dim_size());
+    std::vector ranks;
+    for (size_t j = 0; j < rank_num; ++j) {
+      uint32_t rank = parallel_group_ranks.dim(SizeToLong(j));
+      ranks.push_back(rank);
+    }
+
+    std::pair> group = std::make_pair(group_name, ranks);
+    group_info_map->push_back(group);
+  }
+
+  return SUCCESS;
+}
+
 Status StrategyCheckpoint::Load(StrategyMap *strategy_map) {
   if (strategy_map == nullptr) {
     MS_LOG(EXCEPTION) << "Failure:strategy_map is nullptr";
@@ -141,5 +176,27 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInf
   output.close();
   return SUCCESS;
 }
+
+Status StrategyCheckpoint::SaveGroupInfo(const GroupInfoMap &group_info_map) {
+  straspb::ParallelGroupMap parallel_group_map;
+  for (auto &group : group_info_map) {
+    straspb::ParallelGroupItem *parallel_group_item = parallel_group_map.add_parallel_group_item();
+    MS_EXCEPTION_IF_NULL(parallel_group_item);
+    parallel_group_item->set_group_name(group.first);
+    straspb::ParallelGroupRanks *parallel_group_ranks = parallel_group_item->mutable_parallel_group_ranks();
+    MS_EXCEPTION_IF_NULL(parallel_group_ranks);
+    for (auto &rank : group.second) {
+      parallel_group_ranks->add_dim(rank);
+    }
+  }
+
+  std::fstream output(group_info_save_file_, std::ios::out | std::ios::trunc | std::ios::binary);
+  if (!parallel_group_map.SerializeToOstream(&output)) {
+    MS_LOG(ERROR) << "Save strategy file failed";
+    return FAILED;
+  }
+  output.close();
+  return SUCCESS;
+}
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h b/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h
index 9dc71cdded..5048f21ab9 100644
--- a/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h
+++ b/mindspore/ccsrc/frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h
@@ -32,6 +32,7 @@ namespace parallel {
 using StrategyMap = std::unordered_map;
 using TensorInfoMap = std::unordered_map;
 using ManualShapeMap = std::unordered_map>>;
+using GroupInfoMap = std::vector>>;
 class StrategyCheckpoint {
  public:
   StrategyCheckpoint() {
@@ -40,11 +41,16 @@ class StrategyCheckpoint {
     load_checkpoint_on_ = false;
     save_file_ = "";
     save_checkpoint_on_ = false;
+    group_info_save_file_ = "";
+    group_info_save_on_ = false;
   }
   ~StrategyCheckpoint() = default;
 
   Status Load(StrategyMap *strategy_map);
+  Status LoadGroupInfo(const std::string &file, GroupInfoMap *group_info_map);
   Status Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map, ManualShapeMap *manual_shape_map);
+  Status SaveGroupInfo(const GroupInfoMap &group_info_map);
+  bool group_info_save_on() const { return group_info_save_on_; }
   static StrategyCheckpoint &GetInstance();
 
   bool LoadCheckPointOn() const { return load_checkpoint_on_; }
@@ -57,6 +63,8 @@ class StrategyCheckpoint {
   bool save_checkpoint_on_;
   bool CheckPointExit(const std::string path) const;
   int64_t current_stage_;
+  std::string group_info_save_file_;
+  bool group_info_save_on_;
 };
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/pipeline/jit/init.cc b/mindspore/ccsrc/pipeline/jit/init.cc
index da92fc3258..5d56c5b8fd 100644
--- a/mindspore/ccsrc/pipeline/jit/init.cc
+++ b/mindspore/ccsrc/pipeline/jit/init.cc
@@ -157,6 +157,7 @@ PYBIND11_MODULE(_c_expression, m) {
          "Set strategy checkpoint save file.")
     .def("get_strategy_ckpt_load_file", &ParallelContext::strategy_ckpt_load_file, "Get strategy checkpoint load file.")
     .def("get_strategy_ckpt_save_file", &ParallelContext::strategy_ckpt_save_file, "Get strategy checkpoint save file.")
+    .def("set_group_ckpt_save_file", &ParallelContext::set_group_ckpt_save_file, "Set group checkpoint save file.")
     .def("set_pipeline_stage_split_num", &ParallelContext::set_pipeline_stage_split_num, "Set pipeline stage split num.")
     .def("get_pipeline_stage_split_num", &ParallelContext::pipeline_stage_split_num, "Get pipeline stage split num.")
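SaveGroupInfo and LoadGroupInfo form the round trip: the (group name, rank list) pairs are packed into the new ParallelGroupMap message defined in node_strategy.proto below, serialized to group_info_save_file_, and parsed back on load. A minimal sketch of the round trip, assuming the file argument matches the configured group_ckpt_save_file so that both sides touch the same path and group_info_save_on() is true:

    #include <string>
    #include "frontend/parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"

    // Sketch only: save the recorded groups, then read them back from the same file.
    void GroupInfoRoundTrip(const mindspore::parallel::GroupInfoMap &groups, const std::string &file) {
      auto &ckpt = mindspore::parallel::StrategyCheckpoint::GetInstance();
      if (ckpt.group_info_save_on() && ckpt.SaveGroupInfo(groups) != mindspore::parallel::SUCCESS) {
        return;  // serialization failed
      }
      mindspore::parallel::GroupInfoMap restored;
      if (ckpt.LoadGroupInfo(file, &restored) != mindspore::parallel::SUCCESS) {
        return;  // file missing or parse failed
      }
      // restored now holds the same (group name, rank list) pairs that were written.
    }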
diff --git a/mindspore/ccsrc/utils/node_strategy.proto b/mindspore/ccsrc/utils/node_strategy.proto
index daa51880d6..ffb8f4d87d 100644
--- a/mindspore/ccsrc/utils/node_strategy.proto
+++ b/mindspore/ccsrc/utils/node_strategy.proto
@@ -61,6 +61,19 @@ message ParallelLayoutItem {
   required ParallelLayouts parallel_layouts = 2;
 }
 
+message ParallelGroupRanks {
+  repeated uint32 dim = 1;
+}
+
+message ParallelGroupItem {
+  required string group_name = 1;
+  required ParallelGroupRanks parallel_group_ranks = 2;
+}
+
+message ParallelGroupMap {
+  repeated ParallelGroupItem parallel_group_item = 1;
+}
+
 message ParallelStrategyMap {
   required uint32 current_stage = 1;
   repeated ParallelStrategyItem parallel_strategy_item = 2;
diff --git a/mindspore/parallel/_auto_parallel_context.py b/mindspore/parallel/_auto_parallel_context.py
index d97daca266..ebc5dcd153 100644
--- a/mindspore/parallel/_auto_parallel_context.py
+++ b/mindspore/parallel/_auto_parallel_context.py
@@ -283,6 +283,15 @@ class _AutoParallelContext:
         self.check_context_handle()
         return self._context_handle.get_strategy_ckpt_save_file()
 
+    def set_group_ckpt_save_file(self, group_ckpt_save_file):
+        """Set group checkpoint save path."""
+        self.check_context_handle()
+        import os
+        dir_path = os.path.dirname(group_ckpt_save_file)
+        if dir_path and not os.path.exists(dir_path):
+            os.makedirs(dir_path)
+        self._context_handle.set_group_ckpt_save_file(group_ckpt_save_file)
+
     def get_parameter_broadcast_is_set(self):
         """Get parameter broadcast is set or not."""
         self.check_context_handle()
@@ -505,6 +514,7 @@ _set_auto_parallel_context_func_map = {
     "parameter_broadcast": auto_parallel_context().set_parameter_broadcast,
     "strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file,
     "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file,
+    "group_ckpt_save_file": auto_parallel_context().set_group_ckpt_save_file,
     "full_batch": auto_parallel_context().set_full_batch,
     "enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer,
     "grad_accumulation_step": auto_parallel_context().set_grad_accumulation_step,
@@ -533,7 +543,7 @@ _get_auto_parallel_context_func_map = {
                  loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
                  parameter_broadcast=bool, strategy_ckpt_load_file=str,
                  strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
-                 grad_accumulation_step=int, all_reduce_fusion_config=list)
+                 grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str)
 def _set_auto_parallel_context(**kwargs):
     """
@@ -574,6 +584,7 @@ def _set_auto_parallel_context(**kwargs):
             broadcast. Default: False.
         strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
         strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
+        group_ckpt_save_file (str): The path to save parallel group checkpoint. Default: ''
        full_batch (bool): Whether to load the whole batch on each device. Default: False.
         enable_parallel_optimizer (bool): Enable using optimizer segmentation or not. Default: False.
         all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices.
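On the Python side the new setting rides on the existing auto-parallel context plumbing, so enabling it is a one-argument change, as the updated test below shows. A usage sketch, assuming the same set_auto_parallel_context helper the test imports; the paths are placeholders:

    # Sketch only: the directory of the group checkpoint is created if it is missing.
    set_auto_parallel_context(device_num=8, global_rank=0,
                              strategy_ckpt_save_file="./strategy.ckpt",
                              group_ckpt_save_file="./group.ckpt")
    # After the network is compiled, ./group.ckpt holds every communication group
    # (name plus rank list) created for the model; a later process can replay it
    # with CreateGroupsByCkptFile on the C++ side.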
diff --git a/tests/ut/cpp/stub/parallel_strategy_checkpoint/parallel_strategy_checkpoint_stub.cc b/tests/ut/cpp/stub/parallel_strategy_checkpoint/parallel_strategy_checkpoint_stub.cc
index 6ae883cfbd..6e88549f8c 100644
--- a/tests/ut/cpp/stub/parallel_strategy_checkpoint/parallel_strategy_checkpoint_stub.cc
+++ b/tests/ut/cpp/stub/parallel_strategy_checkpoint/parallel_strategy_checkpoint_stub.cc
@@ -31,5 +31,9 @@ Status StrategyCheckpoint::Load(StrategyMap* strategy_map) { return SUCCESS; }
 
 Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map,
                                 ManualShapeMap *manual_shape_map) { return SUCCESS; }
+
+Status StrategyCheckpoint::LoadGroupInfo(const std::string &file, GroupInfoMap *group_info_map) { return SUCCESS; }
+
+Status StrategyCheckpoint::SaveGroupInfo(const GroupInfoMap &group_info_map) { return SUCCESS; }
 }  // namespace parallel
 }  // namespace mindspore
diff --git a/tests/ut/python/parallel/test_strategy_checkpoint.py b/tests/ut/python/parallel/test_strategy_checkpoint.py
index 9a2db97951..d6bb9eeffd 100644
--- a/tests/ut/python/parallel/test_strategy_checkpoint.py
+++ b/tests/ut/python/parallel/test_strategy_checkpoint.py
@@ -75,7 +75,8 @@ def test_six_matmul_save():
             return out
 
     reset_auto_parallel_context()
-    set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_save_file="./strategy_stage1.ckpt")
+    set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_save_file="./strategy_stage1.ckpt",
+                              group_ckpt_save_file="./group_stage1.ckpt")
     strategy1 = ((8, 1), (1, 1))
     strategy2 = ((1, 8), (8, 1))
     strategy3 = ((2, 2), (2, 2))
@@ -137,7 +138,8 @@ def test_six_matmul_load():
             return out
 
     reset_auto_parallel_context()
-    set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_load_file="./strategy_stage1.ckpt")
+    set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_load_file="./strategy_stage1.ckpt",
+                              group_ckpt_save_file="./group_stage1.ckpt")
     strategy1 = ((8, 1), (1, 1))
     strategy3 = ((8, 1), (1, 1))
     strategy4 = ((8, 1), (1, 1))