!11472 support to checkpoint group info

From: @yangzhenzhang
Reviewed-by: @stsuteng,@kisnwang
Signed-off-by: @stsuteng
pull/11472/MERGE
Committed by mindspore-ci-bot via Gitee, 4 years ago
commit 0641940b87
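
The test changes at the end of this diff show how the new option is driven from user code: group_ckpt_save_file is passed to set_auto_parallel_context alongside the existing strategy-checkpoint options. A minimal sketch (the import path is assumed; it mirrors the updated unit test below):

# Sketch only: mirrors the updated test case in this commit. The public
# set_auto_parallel_context wrapper is assumed to forward the new
# group_ckpt_save_file kwarg to _AutoParallelContext.set_group_ckpt_save_file.
from mindspore import context

context.set_auto_parallel_context(device_num=8, global_rank=0,
                                  strategy_ckpt_save_file="./strategy_stage1.ckpt",
                                  group_ckpt_save_file="./group_stage1.ckpt")

With this set, the communication groups created during StepParallel are serialized to the given file in addition to the operator strategies.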

@@ -124,6 +124,10 @@ void ParallelContext::set_strategy_ckpt_save_file(const std::string &strategy_ck
  strategy_ckpt_save_file_ = strategy_ckpt_save_file;
}

void ParallelContext::set_group_ckpt_save_file(const std::string &group_ckpt_save_file) {
  group_ckpt_save_file_ = group_ckpt_save_file;
}

void ParallelContext::SetAllReduceFusionSplitIndices(const std::vector<uint32_t> indices, const std::string &group) {
  all_reduce_fusion_split_indices_[group] = indices;
}

@@ -102,6 +102,8 @@ class ParallelContext {
  std::string strategy_ckpt_load_file() const { return strategy_ckpt_load_file_; }
  void set_strategy_ckpt_save_file(const std::string &strategy_ckpt_save_file);
  std::string strategy_ckpt_save_file() const { return strategy_ckpt_save_file_; }
  void set_group_ckpt_save_file(const std::string &group_ckpt_save_file);
  std::string group_ckpt_save_file() const { return group_ckpt_save_file_; }

  void set_enable_parallel_optimizer(bool enable_parallel_optimizer) {
    enable_parallel_optimizer_ = enable_parallel_optimizer;
@@ -132,6 +134,7 @@ class ParallelContext {
  std::map<std::string, std::vector<uint32_t>> all_reduce_fusion_split_sizes_;
  std::string strategy_ckpt_load_file_;
  std::string strategy_ckpt_save_file_;
  std::string group_ckpt_save_file_;
  bool enable_parallel_optimizer_;
};

@@ -83,6 +83,7 @@ class DeviceManager {
  void Clear();
  std::string world_group() const { return gm_.world_group(); }
  std::vector<std::pair<std::string, std::vector<uint32_t>>> group_info() const { return gm_.group_info(); }
  std::string FindRankListNameByHashName(const std::string &hash_name);

 private:

@@ -40,16 +40,16 @@ class DynCreator {
 public:
  ~DynCreator() = default;
  // create static singleton dyn_creator instance
  static DynCreator &Instance() {
    static DynCreator fac = DynCreator();
    return fac;
  }
  // register
  void Register(std::string name, CreatFn func) { (void)Function_map_.insert(std::make_pair(name, func)); }
  // creator
  OperatorInfoPtr Create(const std::string &name, const Shapes &shape_in, const Shapes &shape_out,
                         const PrimitiveAttrs &attrs, size_t count) {
    std::string op_name = name + std::to_string(count);
    auto iter = Function_map_.find(name);
    if (iter == Function_map_.end()) {
@@ -67,7 +67,7 @@ class DynCreator {
class RegisterAction {
 public:
  RegisterAction(const std::string &name, CreatFn creatfn) : name_(name) {
    DynCreator::Instance().Register(name, creatfn);
  }
  ~RegisterAction() = default;

@@ -17,6 +17,7 @@
#include "frontend/parallel/group_manager.h"
#include <algorithm>
#include <vector>
#include <utility>
#include "backend/session/executor_manager.h"
#include "frontend/parallel/device_manager.h"
#include "utils/comm_manager.h"
@@ -109,6 +110,9 @@ Status GroupManager::CreateGroup(const std::string &group_name, const std::vecto
    return Status::FAILED;
  }

  std::pair<std::string, std::vector<uint32_t>> group_info = std::make_pair(group_name, ranks);
  group_info_.push_back(group_info);
  MS_LOG(INFO) << "Create group success, group name is " << group_name;
  return Status::SUCCESS;
}
@@ -187,5 +191,27 @@ Status GroupManager::FindGroup(const std::string &name, mindspore::parallel::Gro
}

void GroupManager::Clear() { (void)DestroyAllGroups(); }

Status CreateGroups(const std::vector<std::pair<std::string, std::vector<uint32_t>>> &group_info) {
  // Create group through the executor
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  std::string device_name = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
  auto executor = session::ExecutorManager::Instance().GetExecutor(device_name, device_id);
  MS_EXCEPTION_IF_NULL(executor);

  for (auto &group : group_info) {
    bool ret = executor->CreateCommGroup(group.first, group.second);
    if (!ret) {
      MS_LOG(ERROR) << "Create group failed, group name is " << group.first << ", ranks is " << group.second;
      return FAILED;
    }
    MS_LOG(INFO) << "Create group success, group name is " << group.first << ", ranks is " << group.second;
  }
  return SUCCESS;
}
}  // namespace parallel
}  // namespace mindspore

@@ -21,6 +21,7 @@
#include <map>
#include <string>
#include <vector>
#include <utility>
#include "frontend/parallel/device.h"
#include "frontend/parallel/status.h"
@@ -62,6 +63,7 @@ class GroupManager {
  Status FindGroup(const std::string &name, Group **group);
  std::string world_group() const { return world_group_; }
  void set_world_group(const std::string &name) { world_group_ = name; }
  std::vector<std::pair<std::string, std::vector<uint32_t>>> group_info() const { return group_info_; }
  void Clear();

 private:
@@ -69,7 +71,10 @@ class GroupManager {
  // the key is group name (name_)
  std::map<std::string, Group> groups_;
  std::string world_group_;
  std::vector<std::pair<std::string, std::vector<uint32_t>>> group_info_;
};

Status CreateGroups(const std::vector<std::pair<std::string, std::vector<uint32_t>>> &group_info);
}  // namespace parallel
}  // namespace mindspore

@@ -160,7 +160,7 @@ Status ReduceMethod::InferForwardCommunication() {
  Shape group_creat_map;
  // if repeated calculation and the repeated_calc_num_ insert to the first dimension of dev matrix,
  // it need to handle the first dimension of map.
  if ((dev_matrix_shape_.size() > size) && !repeated_num_in_dev_matrix_right_) {
    group_creat_map.push_back(SizeToInt(dev_matrix_shape_.size() - size_t(1)));
  }
@@ -200,12 +200,12 @@ Status ReduceMethod::InferForwardCommunication() {
}

ForwardOp CreateReduceMeanForwardOp(const std::vector<Group> &forward_group, const TypePtr &dtype) {
  // Create AllReduceSum op
  Operator op0 = CreateAllReduceOp(REDUCE_OP_SUM, forward_group[0].name());
  std::string group_name = forward_group[0].name();
  MS_LOG(INFO) << "The group of forward all reduce is " << group_name;

  // Create RealDiv op
  OperatorName operator1_name = REAL_DIV;
  std::vector<Device> device_list = forward_group[0].GetDevicesList();
  auto divisor = static_cast<float>(device_list.size());
@@ -237,7 +237,7 @@ Status ReduceMeanInfo::InferForwardCommunication() {
  Shape group_creat_map;
  // if repeated calculation and the repeated_calc_num_ insert to the first dimension of dev matrix,
  // it need to handle the first dimension of map.
  if ((dev_matrix_shape_.size() > size) && !repeated_num_in_dev_matrix_right_) {
    group_creat_map.push_back(SizeToInt(dev_matrix_shape_.size() - size_t(1)));
  }

@@ -326,7 +326,7 @@ void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node) {
    std::string instance_name_base = FORWARD_OP;
    std::string instance_name = instance_name_base + "_" + CreateInstanceName(node, index);
    std::vector<AnfNodePtr> forward_input = CreateInput(forward_op[index], node_to_insert, instance_name);
    CNodePtr forward_node = func_graph->NewCNode(forward_input);  // using NewCNode to create anfnode
    MS_EXCEPTION_IF_NULL(forward_node);
    ScopePtr scope = node->scope();
    MS_EXCEPTION_IF_NULL(scope);
@@ -371,10 +371,10 @@ void InsertRedistribution(const RedistributionOpListPtr &redistribution_oplist_p
    if (pos >= SizeToLong(node->inputs().size())) {
      MS_LOG(EXCEPTION) << "InsertRedistribution:pos can't be larger than node's inputs'size";
    }
    // Create new node
    AnfNodePtr target_node = node->input(LongToSize(pos));
    MS_EXCEPTION_IF_NULL(target_node);
    // Create instance_name
    auto op = (redistribution_oplist_ptr->first)[index];
    std::string op_name = (redistribution_oplist_ptr->first)[index].first;
    std::string instance_name_base = REDISTRIBUTION_OP;
@@ -400,7 +400,7 @@ void InsertGetTensorSliceOp(const Operator &op, const CNodePtr &node, const Func
    MS_LOG(EXCEPTION) << "InsertGetTensorSliceOp: pos can't be larger than node's inputs'size, the instance name is "
                      << instance_name;
  }
  // Create new node
  AnfNodePtr pre_node = node->input(LongToSize(pos));
  MS_EXCEPTION_IF_NULL(pre_node);
  InsertNode(op, node, LongToSize(pos), pre_node, func_graph, instance_name);
@@ -595,7 +595,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_
  CNodePtr insert_node_new;
  if (AnfNodeIsPrimitive(node, MAKE_TUPLE) || AnfNodeIsPrimitive(node, MAKE_LIST)) {
    MS_LOG(INFO) << "No need to insert redistribution op between make_tuple node and the next node";
    return;
  }
  if (IsValueNode<Primitive>(node->input(0))) {
@@ -883,10 +883,10 @@ void StepReplaceGraph(const ReplaceGraphPtr &replace_graph, const CNodePtr &node
  if (manager == nullptr) {
    MS_LOG(EXCEPTION) << "Failure:AddNode error since manager is nullptr";
  }
  // Solve the input order
  // For example input_node:{segment_sum:1, segment_sum:2, gahter:2}
  // The Original code here will bind the all operations to the first inputs of these operatos
  // However, the segment_sum operation needs two inputs, To solve this
  // We maintain a dict to count the times of the same operations,
  // and bind the inputs according to the times of the op appears.
  static std::unordered_map<AnfNodePtr, int> input_map = {};
@@ -1241,9 +1241,9 @@ OperatorInfoPtr OperatorInstanceByName(const std::string &name, const PrimitiveA
    }
  }
  OperatorInfoPtr operator_ =
    (OperatorInfoPtr)DynCreator::Instance().Create(distribute_opname, shape_list[0], shape_list[1], attrs, TOTAL_OPS);
  if (operator_ == nullptr) {
    MS_LOG(INFO) << "Create " << name << " failed";
    return nullptr;
  }
  std::string origin_name = operator_->name();
@@ -1261,7 +1261,7 @@ OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs
    if (IsInBatchParallelBlackList(prim)) {
      MS_LOG(EXCEPTION) << "Operator " << prim->name() << " is not supported yet in auto parallel mode.";
    }
    MS_LOG(INFO) << "Create " << prim->name() << " failed, use batch parallel";
    operator_ = OperatorInstanceByName(BATCH_PARALLEL, attrs, shape_list);
    MS_EXCEPTION_IF_NULL(operator_);
  }
@@ -1351,7 +1351,7 @@ Shapes GetNodeShape(const AnfNodePtr &node) {
  }
  if (cnode->input(0)->isa<CNode>()) {
    if (cnode->inputs().size() < 2) {
      MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " size is smaller than 2";
    }
    base_shape_ptr = cnode->input(1)->Shape();
  }
@@ -2546,7 +2546,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
  bool has_backward = !sens_loss_pairs.empty();
  // split sens must before inserting the operators.
  for (auto &pair : sens_loss_pairs) {
    // If the shape of grad-sens tensor is not [] or [1], use get tensor slice to handle it.
    // If the type of sens node is not Tensor, it is unsupported now, do nothing default.
    if (IsLastStage()) {
      StepSplitSens(pair);
@@ -2703,7 +2703,7 @@ void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes) {
      auto param_split_shapes = gatherv2_info->param_split_shapes();
      auto index_offsets = gatherv2_info->index_offsets();
      if (param_split_shapes.size() != index_offsets.size()) {
        MS_LOG(EXCEPTION) << "In manual split, the param_split_shapes and index_offsets length should be same.";
      }
      std::vector<std::pair<int64_t, int64_t>> manual_shape;
      for (int64_t i = 0; i < UlongToLong(param_split_shapes.size()); ++i) {
@@ -2713,6 +2713,7 @@ void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes) {
      }
    }
  }

  if (StrategyCheckpoint::GetInstance().Save(stra_map, tensor_info_map, &manual_shape_map) != SUCCESS) {
    MS_LOG(EXCEPTION) << "Save strategy checkpoint failed";
  }
@@ -3142,6 +3143,19 @@ void CheckParameterSplit(const std::vector<AnfNodePtr> &all_nodes) {
  }
}

bool CreateGroupsByCkptFile(const std::string &file) {
  GroupInfoMap group_info_map;
  if (StrategyCheckpoint::GetInstance().LoadGroupInfo(file, &group_info_map) != SUCCESS) {
    return false;
  }

  if (CreateGroups(group_info_map) != SUCCESS) {
    return false;
  }
  MS_LOG(INFO) << "Create groups by checkpoint file success";
  return true;
}

bool IsUsedParameter(const FuncGraphPtr &graph, const AnfNodePtr &parameter) {
  MS_EXCEPTION_IF_NULL(graph);
  MS_EXCEPTION_IF_NULL(parameter);
@@ -3290,6 +3304,12 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
  // ForwardCommunication BackwardCommunication TensorRedistribution
  ParallelCommunication(root, all_nodes, manager);

  auto group_info = g_device_manager->group_info();
  if (StrategyCheckpoint::GetInstance().group_info_save_on() &&
      StrategyCheckpoint::GetInstance().SaveGroupInfo(group_info) != SUCCESS) {
    MS_LOG(EXCEPTION) << "Save group info failed";
  }

  DumpGraph(root, std::string(STEP_PARALLEL_END));

  // step parallel only run once

@@ -109,7 +109,7 @@ void CoverSliceShape(const FuncGraphPtr &root);

void SetVirtualDatasetStrategy(const CNodePtr &node);

// Create parallel operator for primitive node(has strategy)
void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_training = true);

TensorLayout GetInputLayoutFromCNode(const std::pair<AnfNodePtr, int64_t> &node_pair);
@@ -163,6 +163,8 @@ void ApplyParallelOptOnParam(TensorLayout *tensor_layout, const OperatorInfoPtr

void SetLastNodeStrategy(const StrategyPtr strategyPtr);

bool CreateGroupsByCkptFile(const std::string &file);

void FindLastNodesUniqueId(const std::vector<AnfNodePtr> &all_nodes, std::vector<std::string> *unique_ids);
}  // namespace parallel
}  // namespace mindspore

@@ -34,6 +34,8 @@ StrategyCheckpoint &StrategyCheckpoint::GetInstance() {
    instance.load_checkpoint_on_ = !ParallelContext::GetInstance()->strategy_ckpt_load_file().empty();
    instance.save_file_ = ParallelContext::GetInstance()->strategy_ckpt_save_file();
    instance.save_checkpoint_on_ = !ParallelContext::GetInstance()->strategy_ckpt_save_file().empty();
    instance.group_info_save_file_ = ParallelContext::GetInstance()->group_ckpt_save_file();
    instance.group_info_save_on_ = !ParallelContext::GetInstance()->group_ckpt_save_file().empty();
  }
  return instance;
}
@@ -46,6 +48,39 @@ bool StrategyCheckpoint::CheckPointExit(const std::string path) const {
  return false;
}

Status StrategyCheckpoint::LoadGroupInfo(const std::string &file, GroupInfoMap *group_info_map) {
  MS_EXCEPTION_IF_NULL(group_info_map);
  if (!CheckPointExit(file)) {
    MS_LOG(EXCEPTION) << "CheckPoint file is not found";
  }
  straspb::ParallelGroupMap parallel_group_map;
  std::fstream input(file, std::ios::in | std::ios::binary);
  if (!parallel_group_map.ParseFromIstream(&input)) {
    MS_LOG(ERROR) << "Load strategy file failed";
    return FAILED;
  }
  input.close();

  size_t group_num = LongToSize(parallel_group_map.parallel_group_item_size());
  for (size_t i = 0; i < group_num; ++i) {
    straspb::ParallelGroupItem parallel_group_item = parallel_group_map.parallel_group_item(SizeToLong(i));
    std::string group_name = parallel_group_item.group_name();

    straspb::ParallelGroupRanks parallel_group_ranks = parallel_group_item.parallel_group_ranks();
    size_t rank_num = LongToSize(parallel_group_ranks.dim_size());
    std::vector<uint32_t> ranks;
    for (size_t j = 0; j < rank_num; ++j) {
      uint32_t rank = parallel_group_ranks.dim(SizeToLong(j));
      ranks.push_back(rank);
    }

    std::pair<std::string, std::vector<uint32_t>> group = std::make_pair(group_name, ranks);
    group_info_map->push_back(group);
  }
  return SUCCESS;
}

Status StrategyCheckpoint::Load(StrategyMap *strategy_map) {
  if (strategy_map == nullptr) {
    MS_LOG(EXCEPTION) << "Failure:strategy_map is nullptr";
@@ -141,5 +176,27 @@ Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInf
  output.close();
  return SUCCESS;
}

Status StrategyCheckpoint::SaveGroupInfo(const GroupInfoMap &group_info_map) {
  straspb::ParallelGroupMap parallel_group_map;
  for (auto &group : group_info_map) {
    straspb::ParallelGroupItem *parallel_group_item = parallel_group_map.add_parallel_group_item();
    MS_EXCEPTION_IF_NULL(parallel_group_item);
    parallel_group_item->set_group_name(group.first);
    straspb::ParallelGroupRanks *parallel_group_ranks = parallel_group_item->mutable_parallel_group_ranks();
    MS_EXCEPTION_IF_NULL(parallel_group_ranks);
    for (auto &rank : group.second) {
      parallel_group_ranks->add_dim(rank);
    }
  }

  std::fstream output(group_info_save_file_, std::ios::out | std::ios::trunc | std::ios::binary);
  if (!parallel_group_map.SerializeToOstream(&output)) {
    MS_LOG(ERROR) << "Save strategy file failed";
    return FAILED;
  }
  output.close();
  return SUCCESS;
}
}  // namespace parallel
}  // namespace mindspore

@@ -32,6 +32,7 @@ namespace parallel {
using StrategyMap = std::unordered_map<std::string, StrategyPtr>;
using TensorInfoMap = std::unordered_map<std::string, TensorInfo>;
using ManualShapeMap = std::unordered_map<std::string, std::vector<std::pair<int64_t, int64_t>>>;
using GroupInfoMap = std::vector<std::pair<std::string, std::vector<uint32_t>>>;
class StrategyCheckpoint {
 public:
  StrategyCheckpoint() {
@@ -40,11 +41,16 @@ class StrategyCheckpoint {
    load_checkpoint_on_ = false;
    save_file_ = "";
    save_checkpoint_on_ = false;
    group_info_save_file_ = "";
    group_info_save_on_ = false;
  }
  ~StrategyCheckpoint() = default;

  Status Load(StrategyMap *strategy_map);
  Status LoadGroupInfo(const std::string &file, GroupInfoMap *group_info_map);
  Status Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map, ManualShapeMap *manual_shape_map);
  Status SaveGroupInfo(const GroupInfoMap &group_info_map);
  bool group_info_save_on() const { return group_info_save_on_; }

  static StrategyCheckpoint &GetInstance();
  bool LoadCheckPointOn() const { return load_checkpoint_on_; }
@@ -57,6 +63,8 @@ class StrategyCheckpoint {
  bool save_checkpoint_on_;
  bool CheckPointExit(const std::string path) const;
  int64_t current_stage_;
  std::string group_info_save_file_;
  bool group_info_save_on_;
};
}  // namespace parallel
}  // namespace mindspore

@@ -157,6 +157,7 @@ PYBIND11_MODULE(_c_expression, m) {
         "Set strategy checkpoint save file.")
    .def("get_strategy_ckpt_load_file", &ParallelContext::strategy_ckpt_load_file, "Get strategy checkpoint load file.")
    .def("get_strategy_ckpt_save_file", &ParallelContext::strategy_ckpt_save_file, "Get strategy checkpoint save file.")
    .def("set_group_ckpt_save_file", &ParallelContext::set_group_ckpt_save_file, "Set group checkpoint save file.")
    .def("set_pipeline_stage_split_num", &ParallelContext::set_pipeline_stage_split_num,
         "Set pipeline stage split num.")
    .def("get_pipeline_stage_split_num", &ParallelContext::pipeline_stage_split_num, "Get pipeline stage split num.")

@@ -61,6 +61,19 @@ message ParallelLayoutItem {
  required ParallelLayouts parallel_layouts = 2;
}

message ParallelGroupRanks {
  repeated uint32 dim = 1;
}

message ParallelGroupItem {
  required string group_name = 1;
  required ParallelGroupRanks parallel_group_ranks = 2;
}

message ParallelGroupMap {
  repeated ParallelGroupItem parallel_group_item = 1;
}

message ParallelStrategyMap {
  required uint32 current_stage = 1;
  repeated ParallelStrategyItem parallel_strategy_item = 2;
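
The group checkpoint written by StrategyCheckpoint::SaveGroupInfo is a binary-serialized ParallelGroupMap. Assuming the messages above are compiled with protoc into a Python module (the module name below is hypothetical), the saved file could be inspected like this:

# Hypothetical inspection script: node_strategy_pb2 stands in for whatever module
# protoc generates from the .proto file that defines ParallelGroupMap.
import node_strategy_pb2

group_map = node_strategy_pb2.ParallelGroupMap()
with open("./group_stage1.ckpt", "rb") as f:
    group_map.ParseFromString(f.read())

for item in group_map.parallel_group_item:
    # Each item carries a group name and the ranks that belong to the group.
    print(item.group_name, list(item.parallel_group_ranks.dim))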

@@ -283,6 +283,15 @@ class _AutoParallelContext:
        self.check_context_handle()
        return self._context_handle.get_strategy_ckpt_save_file()

    def set_group_ckpt_save_file(self, group_ckpt_save_file):
        """Set group checkpoint save path."""
        self.check_context_handle()
        import os
        dir_path = os.path.dirname(group_ckpt_save_file)
        if dir_path and not os.path.exists(dir_path):
            os.makedirs(dir_path)
        self._context_handle.set_group_ckpt_save_file(group_ckpt_save_file)

    def get_parameter_broadcast_is_set(self):
        """Get parameter broadcast is set or not."""
        self.check_context_handle()
@@ -505,6 +514,7 @@ _set_auto_parallel_context_func_map = {
    "parameter_broadcast": auto_parallel_context().set_parameter_broadcast,
    "strategy_ckpt_load_file": auto_parallel_context().set_strategy_ckpt_load_file,
    "strategy_ckpt_save_file": auto_parallel_context().set_strategy_ckpt_save_file,
    "group_ckpt_save_file": auto_parallel_context().set_group_ckpt_save_file,
    "full_batch": auto_parallel_context().set_full_batch,
    "enable_parallel_optimizer": auto_parallel_context().set_enable_parallel_optimizer,
    "grad_accumulation_step": auto_parallel_context().set_grad_accumulation_step,
@@ -533,7 +543,7 @@ _get_auto_parallel_context_func_map = {
                 loss_repeated_mean=bool, parallel_mode=str, auto_parallel_search_mode=str,
                 parameter_broadcast=bool, strategy_ckpt_load_file=str,
                 strategy_ckpt_save_file=str, full_batch=bool, enable_parallel_optimizer=bool,
                 grad_accumulation_step=int, all_reduce_fusion_config=list, group_ckpt_save_file=str)
def _set_auto_parallel_context(**kwargs):
    """
@@ -574,6 +584,7 @@ def _set_auto_parallel_context(**kwargs):
            broadcast. Default: False.
        strategy_ckpt_load_file (str): The path to load parallel strategy checkpoint. Default: ''
        strategy_ckpt_save_file (str): The path to save parallel strategy checkpoint. Default: ''
        group_ckpt_save_file (str): The path to save parallel group checkpoint. Default: ''
        full_batch (bool): Whether to load the whole batch on each device. Default: False.
        enable_parallel_optimizer (bool): Enable using optimizer segmentation or not. Default: False.
        all_reduce_fusion_config (list): Set allreduce fusion strategy by parameters indices.

@@ -31,5 +31,9 @@ Status StrategyCheckpoint::Load(StrategyMap* strategy_map) { return SUCCESS; }
Status StrategyCheckpoint::Save(const StrategyMap &strategy_map, const TensorInfoMap &tensor_info_map,
                                ManualShapeMap *manual_shape_map) { return SUCCESS; }
Status StrategyCheckpoint::LoadGroupInfo(const std::string &file, GroupInfoMap *group_info_map) { return SUCCESS; }
Status StrategyCheckpoint::SaveGroupInfo(const GroupInfoMap &group_info_map) { return SUCCESS; }
}  // namespace parallel
}  // namespace mindspore

@@ -75,7 +75,8 @@ def test_six_matmul_save():
            return out

    reset_auto_parallel_context()
    set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_save_file="./strategy_stage1.ckpt",
                              group_ckpt_save_file="./group_stage1.ckpt")
    strategy1 = ((8, 1), (1, 1))
    strategy2 = ((1, 8), (8, 1))
    strategy3 = ((2, 2), (2, 2))
@@ -137,7 +138,8 @@ def test_six_matmul_load():
            return out

    reset_auto_parallel_context()
    set_auto_parallel_context(device_num=8, global_rank=0, strategy_ckpt_load_file="./strategy_stage1.ckpt",
                              group_ckpt_save_file="./group_stage1.ckpt")
    strategy1 = ((8, 1), (1, 1))
    strategy3 = ((8, 1), (1, 1))
    strategy4 = ((8, 1), (1, 1))
