modify get output layout

pull/12416/head
yangzhenzhang 4 years ago
parent 28cbab85ed
commit 70aa0dc5e2

@@ -956,18 +956,7 @@ void InsertVirtualDivOp(const VirtualDivOp &virtual_div_op, const CNodePtr &node) {
   }
 }
 
-// Only used for InsertMirrorOps
-std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
-  if (!node->isa<Parameter>() && !node->isa<CNode>() && !node->isa<ValueNode>()) {
-    return std::make_pair(nullptr, false);
-  } else if (node->isa<Parameter>()) {
-    auto param_ptr = node->user_data<parallel::TensorLayout>();
-    if (param_ptr != nullptr && !param_ptr->opt_shard_group().empty()) {
-      return std::make_pair(nullptr, false);
-    } else {
-      return std::make_pair(node, false);
-    }
-  } else if (node->isa<ValueNode>()) {
+static std::pair<AnfNodePtr, bool> FindParameterByValueNode(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
   if (IsValueNode<RefKey>(node)) {
     std::vector<AnfNodePtr> param_v = FindParameterByRefKeyNode(node, func_graph);
     if (param_v.size() != 1) {
@@ -977,12 +966,30 @@ std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
     auto param_ptr = param_v[0]->user_data<parallel::TensorLayout>();
     if (param_ptr != nullptr && !param_ptr->opt_shard_group().empty()) {
       return std::make_pair(nullptr, true);
-    } else {
-      return std::make_pair(node, true);
     }
+    return std::make_pair(node, true);
   }
   return std::make_pair(nullptr, false);
-  } else {
+}
+
+// Only used for InsertMirrorOps
+std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
+  if (!node->isa<Parameter>() && !node->isa<CNode>() && !node->isa<ValueNode>()) {
+    return std::make_pair(nullptr, false);
+  }
+
+  if (node->isa<Parameter>()) {
+    auto param_ptr = node->user_data<parallel::TensorLayout>();
+    if (param_ptr != nullptr && !param_ptr->opt_shard_group().empty()) {
+      return std::make_pair(nullptr, false);
+    }
+    return std::make_pair(node, false);
+  }
+
+  if (node->isa<ValueNode>()) {
+    return FindParameterByValueNode(node, func_graph);
+  }
+
   CNodePtr cnode = node->cast<CNodePtr>();
   MS_EXCEPTION_IF_NULL(cnode);
   if (!IsValueNode<Primitive>(cnode->input(0))) {
@@ -992,13 +999,16 @@ std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
       }
       return FindParameter(cnode->input(index), func_graph);
     }
-  } else {
+  }
+
   if (IsSomePrimitive(cnode, RECEIVE) && !cnode->has_user_data<OperatorInfo>()) {
     return std::make_pair(node, false);
   }
+
   if (IsParallelCareNode(cnode)) {
     return std::make_pair(nullptr, false);
-  } else {
+  }
+
   ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
   MS_EXCEPTION_IF_NULL(prim_anf_node);
   for (size_t index = 0; index < cnode->inputs().size(); ++index) {
@@ -1012,9 +1022,6 @@ std::pair<AnfNodePtr, bool> FindParameter(const AnfNodePtr &node, const FuncGraphPtr &func_graph) {
       }
       return FindParameter(cnode->input(index), func_graph);
     }
-      }
-    }
-  }
 
   return std::make_pair(nullptr, false);
 }
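The hunks above are one refactor applied twice: the value-node branch of FindParameter is hoisted into a static helper, and the remaining if/else-if chain is flattened into guard clauses that return early. A minimal sketch of the flattening pattern itself, using hypothetical types rather than the MindSpore API:

#include <utility>

// Before: each case adds a nesting level, and the trailing closes pile up.
std::pair<int, bool> ClassifyNested(int kind) {
  if (kind == 0) {
    return std::make_pair(0, false);
  } else if (kind == 1) {
    return std::make_pair(1, false);
  } else {
    return std::make_pair(2, true);
  }
}

// After: guard clauses exit as soon as a case is decided, so the body stays
// at one indentation level; this is the shape FindParameter takes above.
std::pair<int, bool> ClassifyFlat(int kind) {
  if (kind == 0) {
    return std::make_pair(0, false);
  }
  if (kind == 1) {
    return std::make_pair(1, false);
  }
  return std::make_pair(2, true);
}

Behavior is unchanged because every branch already ended in a return; the else keywords carried no control-flow information.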
@@ -1101,6 +1108,25 @@ static void AddCommOpFusionType(const CNodePtr &comm_node, const AnfNodePtr &par
   MS_LOG(INFO) << "Set comm fusion:" << param->param_info()->name() << "'s fusion type is " << fusion_type;
 }
 
+static bool CheckInsertMirrorOps(const MirrorOps &mirror_ops, const CNodePtr &node, size_t node_size) {
+  if ((node->inputs().size() == 2) && (IsValueNode<ValueSequeue>(node->input(1)))) {
+    MS_LOG(INFO) << "Input is ValueList, skip it.";
+    return false;
+  }
+
+  if ((node->inputs().size() == 2) &&
+      (AnfNodeIsPrimitive(node->input(1), MAKE_TUPLE) || AnfNodeIsPrimitive(node->input(1), MAKE_LIST))) {
+    MS_LOG(INFO) << "The mirror for " << GetPrimName(node) << " has handle by make_tuple node";
+    return false;
+  }
+
+  if (mirror_ops.size() != node_size - 1) {
+    MS_LOG(EXCEPTION) << "Mirrorops's size is wrong! mirror_ops size is " << mirror_ops.size() << ", node_size is "
+                      << node_size - 1;
+  }
+  return true;
+}
+
 void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, const CNodePtr &node) {
   MS_EXCEPTION_IF_NULL(node);
   size_t node_size = node->inputs().size();
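CheckInsertMirrorOps bundles the three pre-checks that previously sat inline in InsertMirrorOps. The size check rests on the CNode input convention: input 0 is the primitive itself, so a node with node_size inputs has node_size - 1 operands, and there must be exactly one mirror OperatorVector per operand. A rough illustration of that invariant with standard-library stand-ins (MirrorOps and CNode are not reproduced here):

#include <cassert>
#include <string>
#include <vector>

int main() {
  // A CNode's inputs: index 0 holds the primitive, indices 1..n-1 the operands.
  std::vector<std::string> node_inputs = {"MatMul", "x", "w"};
  // One operator vector per operand; mirror_ops[i] belongs to input i + 1.
  std::vector<std::vector<std::string>> mirror_ops = {{"mirror_x"}, {"mirror_w"}};
  size_t node_size = node_inputs.size();
  assert(mirror_ops.size() == node_size - 1);  // the invariant the helper enforces
  return 0;
}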
@@ -1113,21 +1139,11 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, const CNodePtr &node) {
       node_size--;
     }
   }
-  if ((node->inputs().size() == 2) && (IsValueNode<ValueSequeue>(node->input(1)))) {
-    MS_LOG(INFO) << "Input is ValueList, skip it.";
-    return;
-  }
-  if ((node->inputs().size() == 2) &&
-      (AnfNodeIsPrimitive(node->input(1), MAKE_TUPLE) || AnfNodeIsPrimitive(node->input(1), MAKE_LIST))) {
-    MS_LOG(INFO) << "The mirror for " << GetPrimName(node) << " has handle by make_tuple node";
+
+  if (!CheckInsertMirrorOps(mirror_ops, node, node_size)) {
     return;
   }
-  if (mirror_ops.size() != node_size - 1) {
-    MS_LOG(EXCEPTION) << "Mirrorops's size is wrong! mirror_ops size is " << mirror_ops.size() << ", node_size is "
-                      << node_size - 1;
-  }
+
   for (size_t index = 1; index < node_size; ++index) {
     OperatorVector backward_op = mirror_ops[index - 1];
     if (backward_op.empty()) {
@@ -1181,7 +1197,8 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, const CNodePtr &node) {
         // pipeline mirror would not be set, which should be supported later
         AddCommOpFusionType(comm_op, param_node_pair.first);
       }
-    } else {
+      continue;
+    }
     for (auto &op : backward_op) {
       AnfNodePtr pre_node = node->input(index);
       InsertMirrorNode(root, op, node, index, pre_node, func_graph, instance_name, param_name);
@@ -1192,7 +1209,6 @@ void InsertMirrorOps(const FuncGraphPtr &root, const MirrorOps &mirror_ops, const CNodePtr &node) {
       }
     }
   }
-}
 
 void BackwardCommunication(const FuncGraphPtr &root, const OperatorInfoPtr &distribute_operator, const CNodePtr &node,
                            const std::vector<std::pair<CNodePtr, LossNodeInfo>> &sens_loss_pairs) {
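Inside the per-input loop the same flattening is done with continue, since an early return is not available there: once the parameter path has attached its communication op, the iteration moves on, and the former else body becomes straight-line code. The shape of that change in isolation, with placeholder logic:

#include <vector>

void ProcessInputs(const std::vector<int> &inputs) {
  for (int input : inputs) {
    if (input < 0) {
      // special case fully handled; move on to the next input
      continue;
    }
    // former else-branch, now one nesting level shallower
    (void)input;
  }
}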
@@ -1849,14 +1865,30 @@ void SetLastNodeStrategy(const StrategyPtr strategyPtr) {
   strategyPtr->ResetInputs(strategys);
 }
 
+static bool CheckExtractInfomation(const CNodePtr &cnode) {
+  if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
+    return false;
+  }
+
+  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
+  PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
+  if ((prim->name() == MAKE_TUPLE) || (prim->name() == MAKE_LIST) || (prim->name() == RECEIVE)) {
+    return false;
+  }
+
+  if (!IsParallelCareNode(cnode)) {
+    return false;
+  }
+  return true;
+}
+
 void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_training) {
   // load strategy map from checkpoint
   StrategyMap stra_map;
-  if (StrategyCheckpoint::GetInstance().LoadCheckPointOn()) {
-    if (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS) {
-      MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
-    }
+  if (StrategyCheckpoint::GetInstance().LoadCheckPointOn() &&
+      (StrategyCheckpoint::GetInstance().Load(&stra_map) != SUCCESS)) {
+    MS_LOG(EXCEPTION) << "Load strategy checkpoint failed";
   }
+
   vector<std::string> last_forward_node_ids;
   if (!is_training) {
     FindLastNodesUniqueId(all_nodes, &last_forward_node_ids);
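Merging the two nested ifs into a single condition is behavior-preserving because && short-circuits: Load(&stra_map) is only evaluated when LoadCheckPointOn() returns true, exactly as in the nested form. A self-contained demonstration (the names here are stand-ins, not the StrategyCheckpoint API):

#include <iostream>

bool LoadEnabled() { return true; }

bool TryLoad() {
  std::cout << "load attempted\n";  // printed only when LoadEnabled() is true
  return false;                     // simulate a failed load
}

int main() {
  // Nested form:
  if (LoadEnabled()) {
    if (!TryLoad()) {
      std::cout << "nested: report failure\n";
    }
  }
  // Merged form; identical evaluation order and side effects:
  if (LoadEnabled() && !TryLoad()) {
    std::cout << "merged: report failure\n";
  }
  return 0;
}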
@@ -1865,34 +1897,32 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_training) {
   for (auto &node : all_nodes) {
     auto cnode = node->cast<CNodePtr>();
-    if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
+    if (!CheckExtractInfomation(cnode)) {
       continue;
     }
+
     SetVirtualDatasetStrategy(cnode);
     ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
     PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
-    if (prim->name() == MAKE_TUPLE || prim->name() == MAKE_LIST || prim->name() == RECEIVE) {
-      continue;
-    }
+
     auto attrs = prim->attrs();
     MS_LOG(INFO) << "extract information: node: " << node->ToString() << " prim " << prim->name();
-    if (IsParallelCareNode(cnode)) {
+
     std::vector<Shapes> shape_list = ExtractShape(cnode);
     if (shape_list.empty()) {
       MS_LOG(EXCEPTION) << "Failure:node " << node->ToString() << " failed to extract shape";
     }
     OperatorInfoPtr operator_ = OperatorInstance(prim, attrs, shape_list);
-    if (operator_ == nullptr) {
-      MS_LOG(EXCEPTION) << "Failure:Primitive " << prim->name() << " OperatorInstance failed";
-    }
+    MS_EXCEPTION_IF_NULL(operator_);
+
     auto &inputs = cnode->inputs();
     std::vector<ValuePtr> input_value;
     for (size_t index = 1; index < inputs.size(); ++index) {
       if (inputs[index]->isa<ValueNode>()) {
         input_value.push_back(GetValueNode(inputs[index]));
-      } else {
-        input_value.emplace_back(nullptr);
+        continue;
       }
+      input_value.emplace_back(nullptr);
     }
+
     StrategyPtr strategyPtr = nullptr;
     (*operator_).set_input_value(input_value);
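The filtering that used to be spread across this loop (null/primitive check, MAKE_TUPLE/MAKE_LIST/RECEIVE skip, IsParallelCareNode) now lives in CheckExtractInfomation, and the rewritten collection loop keeps input_value index-aligned with the operands: a real value for ValueNode inputs, a nullptr placeholder for everything else. A sketch of that alignment with ordinary standard-library types, where an empty optional models a non-ValueNode input:

#include <memory>
#include <optional>
#include <vector>

std::vector<std::shared_ptr<int>> CollectValues(const std::vector<std::optional<int>> &inputs) {
  std::vector<std::shared_ptr<int>> input_value;
  for (const auto &input : inputs) {
    if (input.has_value()) {
      input_value.push_back(std::make_shared<int>(*input));
      continue;  // value recorded; skip the placeholder below
    }
    input_value.emplace_back(nullptr);  // placeholder keeps positions aligned
  }
  return input_value;
}

int main() {
  auto values = CollectValues({std::nullopt, 3, std::nullopt});
  return (values.size() == 3 && values[0] == nullptr && *values[1] == 3) ? 0 : 1;
}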
@@ -1923,7 +1953,8 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_training) {
     } else {
       strategyPtr = stra_map[strategy_key_name];
     }
-    if (strategyPtr != nullptr) {
+
+    MS_EXCEPTION_IF_NULL(strategyPtr);
     if (is_last_nodes && full_batch) {
       SetLastNodeStrategy(strategyPtr);
     }
@@ -1931,10 +1962,6 @@ void ExtractInformation(const std::vector<AnfNodePtr> &all_nodes, bool is_training) {
       MS_LOG(EXCEPTION) << "Failure:operator " << prim->name() << " init failed";
     }
     cnode->set_user_data<OperatorInfo>(operator_);
-    } else {
-      MS_LOG(EXCEPTION) << "ERROR:strategy_ptr is nullptr";
-    }
-  }
   }
 }
@@ -1994,9 +2021,9 @@ std::shared_ptr<TensorLayout> GetOutputLayoutFromCNode(const CNodePtr &cnode, size_t output_index) {
   MS_EXCEPTION_IF_NULL(cnode);
   OperatorInfoPtr distribute_operator = GetDistributeOperator(cnode);
   MS_EXCEPTION_IF_NULL(distribute_operator);
-  if (distribute_operator->outputs_tensor_info().size() < output_index) {
+  if (distribute_operator->outputs_tensor_info().size() <= output_index) {
     MS_LOG(EXCEPTION) << "outputs_tensor_info size is " << distribute_operator->inputs_tensor_info().size()
-                      << ", must be less than output_index " << output_index;
+                      << ", must be greater than output_index " << output_index;
   }
   TensorInfo tensorinfo_out = distribute_operator->outputs_tensor_info()[output_index];
   TensorLayout tensorlayout_out = tensorinfo_out.tensor_layout();
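This last hunk is the fix the commit title refers to: for outputs_tensor_info() of size n, the valid indices are 0 through n - 1, so output_index == n must also be rejected. The old guard used <, which accepted n and let the subsequent [output_index] access read past the end. A compact check of both predicates:

#include <cassert>
#include <vector>

int main() {
  std::vector<int> outputs_tensor_info = {7};  // size() == 1, only index 0 is valid
  size_t output_index = 1;                     // one past the end
  assert(!(outputs_tensor_info.size() < output_index));  // old guard: 1 < 1 is false, misses it
  assert(outputs_tensor_info.size() <= output_index);    // new guard: 1 <= 1 is true, catches it
  return 0;
}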
