|
|
|
@ -956,18 +956,7 @@ void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Only used for InsertMirrorOps
|
|
|
|
|
std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
|
|
|
|
|
if (!node->isa<Parameter>() && !node->isa<CNode>() && !node->isa<ValueNode>()) {
|
|
|
|
|
return std::make_pair(nullptr, false);
|
|
|
|
|
} else if (node->isa<Parameter>()) {
|
|
|
|
|
auto param_ptr = node->user_data<parallel::TensorLayout>();
|
|
|
|
|
if (param_ptr != nullptr && !param_ptr->opt_shard_group().empty()) {
|
|
|
|
|
return std::make_pair(nullptr, false);
|
|
|
|
|
} else {
|
|
|
|
|
return std::make_pair(node, false);
|
|
|
|
|
}
|
|
|
|
|
} else if (node->isa<ValueNode>()) {
|
|
|
|
|
static std::pair<AnfNodePtr, bool> FindParameterByValueNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
|
|
|
|
|
if (IsValueNode<RefKey>(node)) {
|
|
|
|
|
std::vector<AnfNodePtr> param_v = FindParameterByRefKeyNode(node, func_graph);
|
|
|
|
|
if (param_v.size() != 1) {
|
|
|
|
@ -977,12 +966,30 @@ std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGrap
|
|
|
|
|
auto param_ptr = param_v[0]->user_data<parallel::TensorLayout>();
|
|
|
|
|
if (param_ptr != nullptr && !param_ptr->opt_shard_group().empty()) {
|
|
|
|
|
return std::make_pair(nullptr, true);
|
|
|
|
|
} else {
|
|
|
|
|
}
|
|
|
|
|
return std::make_pair(node, true);
|
|
|
|
|
}
|
|
|
|
|
return std::make_pair(nullptr, false);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Only used for InsertMirrorOps
|
|
|
|
|
std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
|
|
|
|
|
if (!node->isa<Parameter>() && !node->isa<CNode>() && !node->isa<ValueNode>()) {
|
|
|
|
|
return std::make_pair(nullptr, false);
|
|
|
|
|
} else {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (node->isa<Parameter>()) {
|
|
|
|
|
auto param_ptr = node->user_data<parallel::TensorLayout>();
|
|
|
|
|
if (param_ptr != nullptr && !param_ptr->opt_shard_group().empty()) {
|
|
|
|
|
return std::make_pair(nullptr, false);
|
|
|
|
|
}
|
|
|
|
|
return std::make_pair(node, false);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (node->isa<ValueNode>()) {
|
|
|
|
|
return FindParameterByValueNode(node, func_graph);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
CNodePtr cnode = node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
if (!IsValueNode<Primitive>(cnode->input(0))) {
|
|
|
|
@ -992,13 +999,16 @@ std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGrap
|
|
|
|
|
}
|
|
|
|
|
return FindParameter(cnode->input(index), func_graph);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (IsSomePrimitive(cnode, RECEIVE) && !cnode->has_user_data<OperatorInfo>()) {
|
|
|
|
|
return std::make_pair(node, false);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (IsParallelCareNode(cnode)) {
|
|
|
|
|
return std::make_pair(nullptr, false);
|
|
|
|
|
} else {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(prim_anf_node);
|
|
|
|
|
for (size_t index = 0; index < cnode->inputs().size(); ++index) {
|
|
|
|
@ -1012,9 +1022,6 @@ std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGrap
|
|
|
|
|
}
|
|
|
|
|
return FindParameter(cnode->input(index), func_graph);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return std::make_pair(nullptr, false);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -1101,6 +1108,25 @@ static void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr &par
|
|
|
|
|
MS_LOG(INFO) << "Set comm fusion:" << param->param_info()->name() << "'s fusion type is " << fusion_type;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static bool CheckInsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node, size_t node_size) {
|
|
|
|
|
if ((node->inputs().size() == 2) && (IsValueNode<ValueSequeue>(node->input(1)))) {
|
|
|
|
|
MS_LOG(INFO) << "Input is ValueList, skip it.";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if ((node->inputs().size() == 2) &&
|
|
|
|
|
(AnfNodeIsPrimitive(node->input(1), MAKE_TUPLE) || AnfNodeIsPrimitive(node->input(1), MAKE_LIST))) {
|
|
|
|
|
MS_LOG(INFO) << "The mirror for " << GetPrimName(node) << " has handle by make_tuple node";
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (mirror_ops.size() != node_size - 1) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Mirrorops's size is wrong! mirror_ops size is " << mirror_ops.size() << ", node_size is "
|
|
|
|
|
<< node_size - 1;
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, const CNodePtr &node) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
size_t node_size = node->inputs().size();
|
|
|
|
@ -1113,21 +1139,11 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
|
|
|
|
|
node_size--;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if ((node->inputs().size() == 2) && (IsValueNode<ValueSequeue>(node->input(1)))) {
|
|
|
|
|
MS_LOG(INFO) << "Input is ValueList, skip it.";
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if ((node->inputs().size() == 2) &&
|
|
|
|
|
(AnfNodeIsPrimitive(node->input(1), MAKE_TUPLE) || AnfNodeIsPrimitive(node->input(1), MAKE_LIST))) {
|
|
|
|
|
MS_LOG(INFO) << "The mirror for " << GetPrimName(node) << " has handle by make_tuple node";
|
|
|
|
|
if (!CheckInsertMirrorOps(mirror_ops, node, node_size)) {
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (mirror_ops.size() != node_size - 1) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Mirrorops's size is wrong! mirror_ops size is " << mirror_ops.size() << ", node_size is "
|
|
|
|
|
<< node_size - 1;
|
|
|
|
|
}
|
|
|
|
|
for (size_t index = 1; index < node_size; ++index) {
|
|
|
|
|
OperatorVector backward_op = mirror_ops[index - 1];
|
|
|
|
|
if (backward_op.empty()) {
|
|
|
|
@ -1181,7 +1197,8 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
|
|
|
|
|
// pipeline mirror would not be set, which should be supported later
|
|
|
|
|
AddCommOpFusionType(comm_op, param_node_pair.first);
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
for (auto &op : backward_op) {
|
|
|
|
|
AnfNodePtr pre_node = node->input(index);
|
|
|
|
|
InsertMirrorNode(root, op, node, index, pre_node, func_graph, instance_name, param_name);
|
|
|
|
@ -1192,7 +1209,6 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, cons
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void BackwardCommunication(const FuncGraphPtr &root, const OperatorInfoPtr &distribute_operator, const CNodePtr &node,
|
|
|
|
|
const std::vector<std::pair<CNodePtr, LossNodeInfo>> &sens_loss_pairs) {
|
|
|
|
@ -1849,14 +1865,30 @@ void SetLastNodeStrategy(const StrategyPtr strategyPtr) {
|
|
|
|
|
strategyPtr->ResetInputs(strategys);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
static bool CheckExtractInfomation(const CNodePtr &cnode) {
|
|
|
|
|
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
|
|
|
|
|
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
|
|
|
|
|
if ((prim->name() == MAKE_TUPLE) || (prim->name() == MAKE_LIST) || (prim->name() == RECEIVE)) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (!IsParallelCareNode(cnode)) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_training) {
|
|
|
|
|
// load strategy map from checkpoint
|
|
|
|
|
StrategyMap stra_map;
|
|
|
|
|
if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) {
|
|
|
|
|
if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
|
|
|
|
|
if (StrategyCheckpoint::GetInstance().LoadCheckPointOn() &&
|
|
|
|
|
(StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS)) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
vector<std::string> last_forward_node_ids;
|
|
|
|
|
if (!is_training) {
|
|
|
|
|
FindLastNodesUniqueId(all_nodes, &last_forward_node_ids);
|
|
|
|
@ -1865,34 +1897,32 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_traini
|
|
|
|
|
|
|
|
|
|
for (auto &node : all_nodes) {
|
|
|
|
|
auto cnode = node->cast<CNodePtr>();
|
|
|
|
|
if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
|
|
|
|
|
if (!CheckExtractInfomation(cnode)) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
SetVirtualDatasetStrategy(cnode);
|
|
|
|
|
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
|
|
|
|
|
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
|
|
|
|
|
if (prim->name() == MAKE_TUPLE || prim->name() == MAKE_LIST || prim->name() == RECEIVE) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto attrs = prim->attrs();
|
|
|
|
|
MS_LOG(INFO) << "extract information: node: " << node->ToString() << " prim " << prim->name();
|
|
|
|
|
if (IsParallelCareNode(cnode)) {
|
|
|
|
|
|
|
|
|
|
std::vector<Shapes> shape_list = ExtractShape(cnode);
|
|
|
|
|
if (shape_list.empty()) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Failure:node " << node->ToString() << " failed to extract shape";
|
|
|
|
|
}
|
|
|
|
|
OperatorInfoPtr operator_ = OperatorInstance(prim, attrs, shape_list);
|
|
|
|
|
if (operator_ == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->name() << " OperatorInstance failed";
|
|
|
|
|
}
|
|
|
|
|
MS_EXCEPTION_IF_NULL(operator_);
|
|
|
|
|
|
|
|
|
|
auto &inputs = cnode->inputs();
|
|
|
|
|
std::vector<ValuePtr> input_value;
|
|
|
|
|
for (size_t index = 1; index < inputs.size(); ++index) {
|
|
|
|
|
if (inputs[index]->isa<ValueNode>()) {
|
|
|
|
|
input_value.push_back(GetValueNode(inputs[index]));
|
|
|
|
|
} else {
|
|
|
|
|
input_value.emplace_back(nullptr);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
input_value.emplace_back(nullptr);
|
|
|
|
|
}
|
|
|
|
|
StrategyPtr strategyPtr = nullptr;
|
|
|
|
|
(*operator_).set_input_value(input_value);
|
|
|
|
@ -1923,7 +1953,8 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_traini
|
|
|
|
|
} else {
|
|
|
|
|
strategyPtr = stra_map[strategy_key_name];
|
|
|
|
|
}
|
|
|
|
|
if (strategyPtr != nullptr) {
|
|
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(strategyPtr);
|
|
|
|
|
if (is_last_nodes && full_batch) {
|
|
|
|
|
SetLastNodeStrategy(strategyPtr);
|
|
|
|
|
}
|
|
|
|
@ -1931,10 +1962,6 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_traini
|
|
|
|
|
MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed";
|
|
|
|
|
}
|
|
|
|
|
cnode->set_user_data<OperatorInfo>(operator_);
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << "ERROR:strategy_ptr is nullptr";
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -1994,9 +2021,9 @@ std::shared_ptr<TensorLayout> GetOutputLayoutFromCNode(const CNodePtr &cnode, si
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(distribute_operator);
|
|
|
|
|
if (distribute_operator->outputs_tensor_info().size() < output_index) {
|
|
|
|
|
if (distribute_operator->outputs_tensor_info().size() <= output_index) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "outputs_tensor_info size is " << distribute_operator->inputs_tensor_info().size()
|
|
|
|
|
<< ", must be less than output_index " << output_index;
|
|
|
|
|
<< ", must be greater than output_index " << output_index;
|
|
|
|
|
}
|
|
|
|
|
TensorInfo tensorinfo_out = distribute_operator->outputs_tensor_info()[output_index];
|
|
|
|
|
TensorLayout tensorlayout_out = tensorinfo_out.tensor_layout();
|
|
|
|
|