|
|
|
@ -272,7 +272,7 @@ OperatorInfoPtr GetDistributeOperator(const CNodePtr &node) {
|
|
|
|
|
if (!IsParallelCareNode(node)) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
OperatorInfoPtr distribute_operator = node->GetUserData<OperatorInfo>();
|
|
|
|
|
OperatorInfoPtr distribute_operator = node->user_data<OperatorInfo>();
|
|
|
|
|
if (distribute_operator == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "GetDistributeOperator:distribute_operator is nullptr";
|
|
|
|
|
}
|
|
|
|
@ -409,7 +409,7 @@ bool IsParallelCareNode(const CNodePtr &cnode) {
|
|
|
|
|
if (prim->name() == GET_NEXT) {
|
|
|
|
|
return true;
|
|
|
|
|
}
|
|
|
|
|
if ((prim->name() == CAST) && !cnode->HasUserData<OperatorInfo>()) {
|
|
|
|
|
if ((prim->name() == CAST) && !cnode->has_user_data<OperatorInfo>()) {
|
|
|
|
|
return false;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -446,7 +446,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_
|
|
|
|
|
if (node_prim->name() == DEPEND && node_pair.second != 1) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (IsParallelCareNode(use_cnode) && use_cnode->HasUserData<OperatorInfo>()) {
|
|
|
|
|
if (IsParallelCareNode(use_cnode) && use_cnode->has_user_data<OperatorInfo>()) {
|
|
|
|
|
Redistribution(node_pair, distribute_operator, insert_node_new, node_pair.second, tensor_redistribution,
|
|
|
|
|
pre_node);
|
|
|
|
|
} else {
|
|
|
|
@ -459,7 +459,7 @@ void StepRedistribution(const CNodePtr &node, const OperatorInfoPtr &distribute_
|
|
|
|
|
void SplitTensor(const AnfNodePtr &node, const CNodePtr &next_node, int index) {
|
|
|
|
|
MS_EXCEPTION_IF_NULL(node);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(next_node);
|
|
|
|
|
OperatorInfoPtr op_info = next_node->GetUserData<OperatorInfo>();
|
|
|
|
|
OperatorInfoPtr op_info = next_node->user_data<OperatorInfo>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(op_info);
|
|
|
|
|
|
|
|
|
|
// If the shape of tensor is [] or [1], no need to split it.
|
|
|
|
@ -584,7 +584,7 @@ void ReplaceOneOp(const Operator &replace_op, const CNodePtr &node) {
|
|
|
|
|
|
|
|
|
|
void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) {
|
|
|
|
|
// step1:get graph manager distribute_operator
|
|
|
|
|
OperatorInfoPtr distribute_operator = node->GetUserData<OperatorInfo>();
|
|
|
|
|
OperatorInfoPtr distribute_operator = node->user_data<OperatorInfo>();
|
|
|
|
|
if (distribute_operator == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Failure:AddNode error since distribute_operator is nullptr";
|
|
|
|
|
}
|
|
|
|
@ -622,7 +622,7 @@ void StepReplaceOp(OperatorVector replace_op, const CNodePtr &node) {
|
|
|
|
|
(void)prim->SetAttrs(attrs);
|
|
|
|
|
}
|
|
|
|
|
if (index == replace_op.size() - 1) {
|
|
|
|
|
replace_node->SetUserData<OperatorInfo>(node->GetUserData<OperatorInfo>());
|
|
|
|
|
replace_node->set_user_data<OperatorInfo>(node->user_data<OperatorInfo>());
|
|
|
|
|
}
|
|
|
|
|
replace_node->set_in_forward_flag(true);
|
|
|
|
|
replace_input[0]->set_scope(scope);
|
|
|
|
@ -702,7 +702,7 @@ LossNodeInfo GetLossNodeInfo(const AnfNodePtr &loss_node) {
|
|
|
|
|
auto pre_cnode = pre_node->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(pre_cnode);
|
|
|
|
|
auto pre_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
|
|
|
|
|
if (pre_prim->name() == CAST && !pre_cnode->HasUserData<OperatorInfo>()) {
|
|
|
|
|
if (pre_prim->name() == CAST && !pre_cnode->has_user_data<OperatorInfo>()) {
|
|
|
|
|
pre_node = pre_cnode->input(1);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
@ -1198,7 +1198,7 @@ std::pair<AnfNodePtr, int> FindParallelCareNode(const AnfNodePtr &node) {
|
|
|
|
|
if (node_prim->name() == DEPEND && node_pair.second != 1) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) {
|
|
|
|
|
if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
|
|
|
|
|
return node_pair;
|
|
|
|
|
} else if (FindParallelCareNode(node_pair.first).first != nullptr) {
|
|
|
|
|
return FindParallelCareNode(node_pair.first);
|
|
|
|
@ -1248,7 +1248,7 @@ void SetParallelShape(const AnfNodePtr ¶meter, const std::pair<AnfNodePtr, i
|
|
|
|
|
MS_LOG(DEBUG) << "SetParallelShape " << parameter->ToString() << " shape " << parameter->Shape()->ToString();
|
|
|
|
|
CNodePtr cnode = res.first->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cnode);
|
|
|
|
|
OperatorInfoPtr distribute_operator = cnode->GetUserData<OperatorInfo>();
|
|
|
|
|
OperatorInfoPtr distribute_operator = cnode->user_data<OperatorInfo>();
|
|
|
|
|
if (distribute_operator == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Failure:node " << cnode->ToString() << " 's OperatorInfoPtr is nullptr";
|
|
|
|
|
}
|
|
|
|
@ -1271,7 +1271,7 @@ void SetParallelShape(const AnfNodePtr ¶meter, const std::pair<AnfNodePtr, i
|
|
|
|
|
TensorLayout tensor_layout = tensorinfo_in.tensor_layout();
|
|
|
|
|
ParameterPtr parameter_ptr = parameter->cast<ParameterPtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(parameter_ptr);
|
|
|
|
|
parameter_ptr->SetUserData<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
|
|
|
|
|
parameter_ptr->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(tensor_layout));
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
void CoverSliceShape(const FuncGraphPtr &root) {
|
|
|
|
@ -1359,7 +1359,7 @@ void SetClonedTensorShapeForOptimizer(const FuncGraphPtr &root) {
|
|
|
|
|
|
|
|
|
|
if (found_be_cloned_parameter) {
|
|
|
|
|
// set the shape and tensor layout for cloned parameter
|
|
|
|
|
cloned_parameter->SetUserData<TensorLayout>(cloned_from_parameter->GetUserData<TensorLayout>());
|
|
|
|
|
cloned_parameter->set_user_data<TensorLayout>(cloned_from_parameter->user_data<TensorLayout>());
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cloned_parameter_node->abstract());
|
|
|
|
|
MS_EXCEPTION_IF_NULL(cloned_from_node->abstract());
|
|
|
|
|
auto cloned_abstract = cloned_parameter_node->abstract()->Clone();
|
|
|
|
@ -1454,7 +1454,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
|
|
|
|
|
(*operator_).set_outputs_dtype(cnode->Type());
|
|
|
|
|
(*operator_).set_cnode(cnode);
|
|
|
|
|
if (prim->name() == RESHAPE) {
|
|
|
|
|
cnode->SetUserData<OperatorInfo>(operator_);
|
|
|
|
|
cnode->set_user_data<OperatorInfo>(operator_);
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
// load strategy checkpoint
|
|
|
|
@ -1489,7 +1489,7 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes) {
|
|
|
|
|
if (operator_->Init(strategyPtr) == FAILED) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed";
|
|
|
|
|
}
|
|
|
|
|
cnode->SetUserData<OperatorInfo>(operator_);
|
|
|
|
|
cnode->set_user_data<OperatorInfo>(operator_);
|
|
|
|
|
} else {
|
|
|
|
|
MS_LOG(EXCEPTION) << "ERROR:strategy_ptr is nullptr";
|
|
|
|
|
}
|
|
|
|
@ -1532,13 +1532,13 @@ std::shared_ptr<TensorLayout> FindNextLayout(const CNodePtr &cnode) {
|
|
|
|
|
if (node_prim->name() == DEPEND && node_pair.second != 1) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
if (IsParallelCareNode(use_apply) && use_apply->HasUserData<OperatorInfo>()) {
|
|
|
|
|
if (IsParallelCareNode(use_apply) && use_apply->has_user_data<OperatorInfo>()) {
|
|
|
|
|
MS_LOG(INFO) << "FindNextLayout success prim " << node_prim->name();
|
|
|
|
|
auto layout = GetInputLayoutFromCNode(node_pair);
|
|
|
|
|
return std::make_shared<TensorLayout>(layout);
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(DEBUG) << "FindNextLayout failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply)
|
|
|
|
|
<< " " << use_apply->HasUserData<OperatorInfo>();
|
|
|
|
|
<< " " << use_apply->has_user_data<OperatorInfo>();
|
|
|
|
|
|
|
|
|
|
auto layout_ptr = FindNextLayout(use_apply);
|
|
|
|
|
if (layout_ptr) {
|
|
|
|
@ -1570,7 +1570,7 @@ std::shared_ptr<TensorLayout> FindPrevParallelCareNodeLayout(const AnfNodePtr &n
|
|
|
|
|
if (!IsValueNode<Primitive>(cnode->input(0))) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) {
|
|
|
|
|
if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
|
|
|
|
|
auto layout_ptr = GetOutputLayoutFromCNode(cnode, output_index);
|
|
|
|
|
if (!layout_ptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed";
|
|
|
|
@ -1614,7 +1614,7 @@ std::shared_ptr<TensorLayout> FindPrevLayout(const AnfNodePtr &node) {
|
|
|
|
|
if (!IsValueNode<Primitive>(cnode->input(0))) {
|
|
|
|
|
return nullptr;
|
|
|
|
|
}
|
|
|
|
|
if (IsParallelCareNode(cnode) && cnode->HasUserData<OperatorInfo>()) {
|
|
|
|
|
if (IsParallelCareNode(cnode) && cnode->has_user_data<OperatorInfo>()) {
|
|
|
|
|
auto layout_ptr = GetOutputLayoutFromCNode(cnode, 0);
|
|
|
|
|
if (!layout_ptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Failure:GetLayoutFromCNode failed";
|
|
|
|
@ -1654,12 +1654,12 @@ void ReshapeInit(const std::vector<AnfNodePtr> &all_nodes) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
|
|
|
|
|
if (!IsParallelCareNode(cnode) || !cnode->HasUserData<OperatorInfo>()) {
|
|
|
|
|
if (!IsParallelCareNode(cnode) || !cnode->has_user_data<OperatorInfo>()) {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
|
|
|
|
|
MS_EXCEPTION_IF_NULL(prim);
|
|
|
|
|
OperatorInfoPtr operator_info = cnode->GetUserData<OperatorInfo>();
|
|
|
|
|
OperatorInfoPtr operator_info = cnode->user_data<OperatorInfo>();
|
|
|
|
|
if (operator_info == nullptr) {
|
|
|
|
|
MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->ToString() << " OperatorInstance is nullptr";
|
|
|
|
|
}
|
|
|
|
@ -1704,7 +1704,7 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
|
|
|
|
|
|
|
|
|
|
auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
|
|
|
|
|
// return -> cast
|
|
|
|
|
if (current_prim->name() == CAST && !pre_cnode->HasUserData<OperatorInfo>()) {
|
|
|
|
|
if (current_prim->name() == CAST && !pre_cnode->has_user_data<OperatorInfo>()) {
|
|
|
|
|
pre_cnode = pre_cnode->input(1)->cast<CNodePtr>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(pre_cnode);
|
|
|
|
|
current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
|
|
|
|
@ -1761,7 +1761,7 @@ TensorLayouts GetLossNodeGradOutputLayout(const CNodePtr &loss_cnode) {
|
|
|
|
|
return ret;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
OperatorInfoPtr operator_info = loss_cnode->GetUserData<OperatorInfo>();
|
|
|
|
|
OperatorInfoPtr operator_info = loss_cnode->user_data<OperatorInfo>();
|
|
|
|
|
MS_EXCEPTION_IF_NULL(operator_info);
|
|
|
|
|
TensorInfo loss_grad_tensor_info;
|
|
|
|
|
size_t op_output_size = operator_info->outputs_tensor_info().size();
|
|
|
|
@ -1799,7 +1799,7 @@ void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_lay
|
|
|
|
|
if (sens_tensor_node->isa<Parameter>()) {
|
|
|
|
|
auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>();
|
|
|
|
|
MS_LOG(DEBUG) << "loss layout " << loss_grad_layout.ToString();
|
|
|
|
|
sens_tensor_param->SetUserData<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout));
|
|
|
|
|
sens_tensor_param->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout));
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(INFO) << "The shape of sens is " << ShapeToString(sens_shape) << ", no need to split sens";
|
|
|
|
|
return;
|
|
|
|
@ -1824,7 +1824,7 @@ void SplitSens(const CNodePtr &grad_sens_node, const TensorLayout &loss_grad_lay
|
|
|
|
|
cloned_abstract->set_shape(parallel_shape);
|
|
|
|
|
sens_tensor_node->set_abstract(cloned_abstract);
|
|
|
|
|
auto sens_tensor_param = sens_tensor_node->cast<ParameterPtr>();
|
|
|
|
|
sens_tensor_param->SetUserData<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout));
|
|
|
|
|
sens_tensor_param->set_user_data<TensorLayout>(std::make_shared<TensorLayout>(loss_grad_layout));
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
MS_LOG(EXCEPTION) << "The type of sens node is not Tensor or Parameter, it is unsupported now.";
|
|
|
|
@ -2131,7 +2131,7 @@ void CheckpointStrategy(const FuncGraphPtr &func_graph) {
|
|
|
|
|
}
|
|
|
|
|
PrimitivePtr prim = GetValueNode<PrimitivePtr>(cnode->input(0));
|
|
|
|
|
MS_EXCEPTION_IF_NULL(prim);
|
|
|
|
|
OperatorInfoPtr operator_info = cnode->GetUserData<OperatorInfo>();
|
|
|
|
|
OperatorInfoPtr operator_info = cnode->user_data<OperatorInfo>();
|
|
|
|
|
if (operator_info) {
|
|
|
|
|
if (operator_info->name().find(RESHAPEINFO) != std::string::npos) {
|
|
|
|
|
continue;
|
|
|
|
|