|
|
|
@ -326,7 +326,7 @@ void ForwardCommunication(OperatorVector forward_op, const CNodePtr &node) {
|
|
|
|
|
std::string instance_name_base = FORWARD_OP;
|
|
|
|
|
std::string instance_name = instance_name_base + "_" + CreateInstanceName(node, index);
|
|
|
|
|
std::vector<AnfNodePtr> forward_input = CreateInput(forward_op[index], node_to_insert, instance_name);
|
|
|
|
|
CNodePtr forward_node = func_graph->NewCNode(forward_input); // using NewCNode to creat anfnode
|
|
|
|
|
CNodePtr forward_node = func_graph->NewCNode(forward_input); // using NewCNode to create anfnode
|
|
|
|
|
MS_EXCEPTION_IF_NULL(forward_node);
|
|
|
|
|
ScopePtr scope = node->scope();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(scope);
|
|
|
|
@ -371,10 +371,10 @@ void InsertRedistribution(const RedistributionOpListPtr &redistribution_oplist_p
|
|
|
|
|
if (pos >= SizeToLong(node->inputs().size())) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "InsertRedistribution:pos can't be larger than node's inputs'size";
|
|
|
|
|
}
|
|
|
|
|
// Creat new node
|
|
|
|
|
// Create new node
|
|
|
|
|
AnfNodePtr target_node = node->input(LongToSize(pos));
|
|
|
|
|
MS_EXCEPTION_IF_NULL(target_node);
|
|
|
|
|
// Creat instance_name
|
|
|
|
|
// Create instance_name
|
|
|
|
|
auto op = (redistribution_oplist_ptr->first)[index];
|
|
|
|
|
std::string op_name = (redistribution_oplist_ptr->first)[index].first;
|
|
|
|
|
std::string instance_name_base = REDISTRIBUTION_OP;
|
|
|
|
@ -400,7 +400,7 @@ void InsertGetTensorSliceOp(const Operator &op, const CNodePtr &node, const Func
|
|
|
|
|
MS_LOG(EXCEPTION) << "InsertGetTensorSliceOp: pos can't be larger than node's inputs'size, the instance name is "
|
|
|
|
|
<< instance_name;
|
|
|
|
|
}
|
|
|
|
|
// Creat new node
|
|
|
|
|
// Create new node
|
|
|
|
|
AnfNodePtr pre_node = node->input(LongToSize(pos));
|
|
|
|
|
MS_EXCEPTION_IF_NULL(pre_node);
|
|
|
|
|
InsertNode(op, node, LongToSize(pos), pre_node, func_graph, instance_name);
|
|
|
|
@ -595,7 +595,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_
|
|
|
|
|
CNodePtr insert_node_new;
|
|
|
|
|
|
|
|
|
|
if (AnfNodeIsPrimitive(node, MAKE_TUPLE) || AnfNodeIsPrimitive(node, MAKE_LIST)) {
|
|
|
|
|
MS_LOG(INFO) << "No need to insert redistribution op betweend make_tuple node and the next node";
|
|
|
|
|
MS_LOG(INFO) << "No need to insert redistribution op between make_tuple node and the next node";
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
if (IsValueNode<Primitive>(node->input(0))) {
|
|
|
|
@ -883,10 +883,10 @@ void StepReplaceGraph(const ReplaceGraphPtr &replace_graph, const CNodePtr &node
|
|
|
|
|
if (manager == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Failure:AddNode error since manager is nullptr";
|
|
|
|
|
}
|
|
|
|
|
// Sovle the input order
|
|
|
|
|
// Solve the input order
|
|
|
|
|
// For example input_node:{segment_sum:1, segment_sum:2, gather:2}
|
|
|
|
|
// The Original code here will bind the all operations to the first inputs of theses operatos
|
|
|
|
|
// However, the segment_sum operation needs two inputs, To sovle this
|
|
|
|
|
// The original code here would bind all the operations to the first inputs of these operators
|
|
|
|
|
// However, the segment_sum operation needs two inputs, To solve this
|
|
|
|
|
// We maintain a dict to count the times of the same operations,
|
|
|
|
|
// and bind the inputs according to the times of the op appears.
|
|
|
|
|
static std::unordered_map<AnfNodePtr, int> input_map = {};
|
|
|
|
@ -1241,9 +1241,9 @@ OperatorInfoPtr OperatorInstanceByName(const std::string &name, const PrimitiveA
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
OperatorInfoPtr operator_ =
|
|
|
|
|
(OperatorInfoPtr)DynCreator::Instance().Creat(distribute_opname, shape_list[0], shape_list[1], attrs, TOTAL_OPS);
|
|
|
|
|
(OperatorInfoPtr)DynCreator::Instance().Create(distribute_opname, shape_list[0], shape_list[1], attrs, TOTAL_OPS);
|
|
|
|
|
if (operator_ == nullptr) {
|
|
|
|
|
MS_LOG(INFO) << "Creat " << name << " failed";
|
|
|
|
|
MS_LOG(INFO) << "Create " << name << " failed";
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
std::string origin_name = operator_->name();
|
|
|
|
@ -1261,7 +1261,7 @@ OperatorInfoPtr OperatorInstance(const PrimitivePtr &prim, const PrimitiveAttrs
|
|
|
|
|
if (IsInBatchParallelBlackList(prim)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Operator " << prim->name() << " is not supported yet in auto parallel mode.";
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(INFO) << "Creat " << prim->name() << " failed, use batch parallel";
|
|
|
|
|
MS_LOG(INFO) << "Create " << prim->name() << " failed, use batch parallel";
|
|
|
|
|
operator_ = OperatorInstanceByName(BATCH_PARALLEL, attrs, shape_list);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(operator_);
|
|
|
|
|
}
|
|
|
|
@ -1351,7 +1351,7 @@ Shapes GetNodeShape(const AnfNodePtr &node) {
|
|
|
|
|
}
|
|
|
|
|
if (cnode->input(0)->isa<CNode>()) {
|
|
|
|
|
if (cnode->inputs().size() < 2) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " size is samller than 2";
|
|
|
|
|
MS_LOG(EXCEPTION) << "GetNodeShape: " << node->ToString() << " size is smaller than 2";
|
|
|
|
|
}
|
|
|
|
|
base_shape_ptr = cnode->input(1)->Shape();
|
|
|
|
|
}
|
|
|
|
@ -2546,7 +2546,7 @@ void ParallelCommunication(const FuncGraphPtr &root, const std::vector<AnfNodePt
|
|
|
|
|
bool has_backward = !sens_loss_pairs.empty();
|
|
|
|
|
// split sens must be done before inserting the operators.
|
|
|
|
|
for (auto &pair : sens_loss_pairs) {
|
|
|
|
|
// If the shape of grad-sens tensor is not [] or [1], use get tensor slice to handel it.
|
|
|
|
|
// If the shape of grad-sens tensor is not [] or [1], use get tensor slice to handle it.
|
|
|
|
|
// If the type of sens node is not Tensor, it is unsupported now, do nothing by default.
|
|
|
|
|
if (IsLastStage()) {
|
|
|
|
|
StepSplitSens(pair);
|
|
|
|
@ -2703,7 +2703,7 @@ void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes) {
|
|
|
|
|
auto param_split_shapes = gatherv2_info->param_split_shapes();
|
|
|
|
|
auto index_offsets = gatherv2_info->index_offsets();
|
|
|
|
|
if (param_split_shapes.size() != index_offsets.size()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "In manual split, the param_split_shapes and index_offsets lenght should be same.";
|
|
|
|
|
MS_LOG(EXCEPTION) << "In manual split, the param_split_shapes and index_offsets length should be same.";
|
|
|
|
|
}
|
|
|
|
|
std::vector<std::pair<int64_t, int64_t>> manual_shape;
|
|
|
|
|
for (int64_t i = 0; i < UlongToLong(param_split_shapes.size()); ++i) {
|
|
|
|
@ -2713,6 +2713,7 @@ void CheckpointStrategy(const std::vector<AnfNodePtr> &all_nodes) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (StrategyCheckpoint::GetInstance().Save(stra_map, tensor_info_map, &manual_shape_map) != SUCCESS) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Save strategy checkpoint failed";
|
|
|
|
|
}
|
|
|
|
@ -3142,6 +3143,19 @@ void CheckParameterSplit(const std::vector<AnfNodePtr> &all_nodes) {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool CreateGroupsByCkptFile(const std::string &file) {
|
|
|
|
|
GroupInfoMap group_info_map;
|
|
|
|
|
if (StrategyCheckpoint::GetInstance().LoadGroupInfo(file, &group_info_map) != SUCCESS) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (CreateGroups(group_info_map) != SUCCESS) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(INFO) << "Create groups by checkpoint file success";
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
bool IsUsedParameter(const FuncGraphPtr &graph, const AnfNodePtr ¶meter) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(parameter);
|
|
|
|
@ -3290,6 +3304,12 @@ bool StepParallel(const FuncGraphPtr &root, const opt::OptimizerPtr &optimizer)
|
|
|
|
|
// ForwardCommunication BackwardCommunication TensorRedistribution
|
|
|
|
|
ParallelCommunication(root, all_nodes, manager);
|
|
|
|
|
|
|
|
|
|
auto group_info = g_device_manager->group_info();
|
|
|
|
|
if (StrategyCheckpoint::GetInstance().group_info_save_on() &&
|
|
|
|
|
StrategyCheckpoint::GetInstance().SaveGroupInfo(group_info) != SUCCESS) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Save group info failed";
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
DumpGraph(root, std::string(STEP_PARALLEL_END));
|
|
|
|
|
|
|
|
|
|
// step parallel only runs once
|
|
|
|
|