@@ -39,6 +39,7 @@
 #include "parallel/auto_parallel/rec_core/rec_partition.h"
 #include "parallel/context.h"
 #include "parallel/ops_info/tmp_identity_info.h"
+#include "parallel/ops_info/reshape_info.h"
 #include "parallel/step_parallel.h"
 #include "parallel/strategy_checkpoint/parallel_strategy_checkpoint.h"
 #include "pipeline/parse/python_adapter.h"
@@ -608,7 +609,8 @@ void ConstructCostGraphEdges(const std::vector<AnfNodePtr> &all_nodes) {
       EdgePtr edge_ptr;
       MS_LOG(INFO) << "Creating edge: " << edge_name;

-      bool follow_strategy = ELEMENTWISE_OP_STRA_FOLLOW && IsElementWiseOperator(prev_prim->name());
+      bool follow_strategy = (prim->name() == RESHAPE) || (prev_prim->name() == RESHAPE) ||
+                             (ELEMENTWISE_OP_STRA_FOLLOW && IsElementWiseOperator(prev_prim->name()));
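+      // Reshape edges must share their neighbor's strategy, so that no redistribution is inserted around Reshape.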
       if (follow_strategy) {
         // Redistribution is not allowed on the edge.
         // Elementwise operators have the same strategy as their previous operators.
@@ -893,6 +895,209 @@ void AugmentCostGraph(const std::vector<AnfNodePtr> &all_nodes) {
   }
 }
+
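+// Check whether the cnode is a Reshape primitive that participates in parallel partitioning.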
+bool FindReshape(const CNodePtr &cnode) {
+  if ((cnode == nullptr) || !IsValueNode<Primitive>(cnode->input(0))) {
+    return false;
+  }
+  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
+  if (!IsParallelCareNode(cnode) || (cnode->operator_info() == nullptr)) {
+    return false;
+  }
+  PrimitivePtr prim = GetValueNode<PrimitivePtr>(prim_anf_node);
+  MS_EXCEPTION_IF_NULL(prim);
+  OperatorInfoPtr operator_info = cnode->operator_info();
+  if (operator_info == nullptr) {
+    MS_LOG(EXCEPTION) << "Failure: Primitive " << prim->ToString() << " OperatorInstance is nullptr";
+  }
+  if (prim->name() != RESHAPE) {
+    return false;
+  }
+  return true;
+}
+
+// Find the previous node, then obtain its strategy_cost_ vector to get its layout vector.
+bool FindPreNodeStraCosts(const AnfNodePtr &node, OperatorInfoPtr *pre_operator_info, int32_t *out_index) {
+  // If the previous node is a parameter, handle it outside this function.
+  if (node->isa<Parameter>()) {
+    return false;
+  }
+  if (!node->isa<CNode>()) {
+    return false;
+  }
+  CNodePtr cnode = node->cast<CNodePtr>();
+  if (!IsValueNode<Primitive>(cnode->input(0))) {
+    return false;
+  }
+  if (IsParallelCareNode(cnode) && (cnode->operator_info() != nullptr)) {
+    *pre_operator_info = cnode->operator_info();
+    *out_index = 0;
+    return true;
+  }
+  ValueNodePtr prim_anf_node = cnode->input(0)->cast<ValueNodePtr>();
+  PrimitivePtr prim = prim_anf_node->value()->cast<PrimitivePtr>();
+  if (prim->name() == TUPLE_GETITEM) {
+    *out_index = GetTupleGetItemIndex(cnode);
+    // find tuple_get_item's previous node
+    auto pre_node = cnode->input(1);
+    if (!pre_node->isa<CNode>()) {
+      MS_LOG(EXCEPTION) << "tuple_getitem's second input is not a cnode";
+    }
+    CNodePtr pre_cnode = pre_node->cast<CNodePtr>();
+    if (IsParallelCareNode(pre_cnode) && (pre_cnode->operator_info() != nullptr)) {
+      *pre_operator_info = pre_cnode->operator_info();
+      return true;
+    }
+    return false;
+  }
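+  // Otherwise search this node's inputs recursively for a parallel-care predecessor;
+  // for Depend, only input 1 carries data, so control inputs are skipped.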
+  for (size_t index = 0; index < cnode->inputs().size(); ++index) {
+    if (prim->name() == DEPEND && index != 1) {
+      continue;
+    }
+    if (!FindPreNodeStraCosts(cnode->inputs()[index], pre_operator_info, out_index)) {
+      continue;
+    }
+    return true;
+  }
+  MS_LOG(WARNING) << "FindPreNodeStraCosts failed; if reshape is not the first primitive, there must be some error";
+  return false;
+}
+
+// Find the next node, then obtain its strategy_cost_ vector to get its layout vector.
+// If reshape's output connects to several primitives, return the first layout found.
+bool FindNextNodeStraCosts(const CNodePtr &cnode, OperatorInfoPtr *next_operator_info, int32_t *in_index) {
+  MS_EXCEPTION_IF_NULL(cnode);
+  MS_EXCEPTION_IF_NULL(cnode->func_graph());
+  FuncGraphManagerPtr manager = cnode->func_graph()->manager();
+  MS_EXCEPTION_IF_NULL(manager);
+  AnfNodeIndexSet node_set = manager->node_users()[cnode];
+  for (auto &node_pair : node_set) {
+    CNodePtr use_apply = node_pair.first->cast<CNodePtr>();
+    if (use_apply == nullptr || !IsValueNode<Primitive>(use_apply->input(0))) {
+      continue;
+    }
+    ValueNodePtr prim_anf_node = use_apply->input(0)->cast<ValueNodePtr>();
+    MS_EXCEPTION_IF_NULL(prim_anf_node);
+    PrimitivePtr node_prim = prim_anf_node->value()->cast<PrimitivePtr>();
+    MS_EXCEPTION_IF_NULL(node_prim);
+    MS_LOG(INFO) << "FindNextNodeStraCosts prim " << node_prim->name();
+    if (node_prim->name() == DEPEND && node_pair.second != 1) {
+      continue;
+    }
+    if (IsParallelCareNode(use_apply) && (use_apply->operator_info() != nullptr)) {
+      MS_LOG(INFO) << "FindNextNodeStraCosts success prim " << node_prim->name();
+      *next_operator_info = use_apply->operator_info();
+      *in_index = node_pair.second - 1;
+      return true;
+    }
+    MS_LOG(DEBUG) << "FindNextNodeStraCosts failed prim " << node_prim->name() << " " << IsParallelCareNode(use_apply)
+                  << " " << (use_apply->operator_info() != nullptr);
+
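+    // This user is not parallel-care itself; keep searching through it for a downstream operator.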
+    if (FindNextNodeStraCosts(use_apply, next_operator_info, in_index)) {
+      return true;
+    }
+  }
+  return false;
+}
+
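+// Derive the partition strategy implied by a tensor layout: stra[i] = shape[i] / slice_shape[i].
+// For example, shape [64, 32] with slice_shape [16, 32] yields strategy [4, 1].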
+void InferStraByTensorInfo(const TensorInfo &pre_out_tensor_info, Dimensions *stra) {
+  Shape shape = pre_out_tensor_info.shape();
+  Shape slice_shape = pre_out_tensor_info.slice_shape();
+  for (size_t i = 0; i < shape.size(); ++i) {
+    if ((slice_shape[i] == 0) || (shape[i] % slice_shape[i] != 0)) {
+      MS_LOG(EXCEPTION) << "slice_shape is wrong in reshape operator";
+    }
+    int32_t dim = (int32_t)(shape[i] / slice_shape[i]);
+    stra->push_back(dim);
+  }
+}
+
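+// For every Reshape in the graph, fix its input layout from each of the previous operator's candidate
+// strategies and its output layout from each of the next operator's candidate strategies, then
+// register a cost for every resulting layout pair.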
+void ReshapeCostCompute(const std::vector<AnfNodePtr> &all_nodes) {
+  for (auto node : all_nodes) {
+    auto cnode = node->cast<CNodePtr>();
+    if (!FindReshape(cnode)) {
+      continue;
+    }
+    MS_ASSERT(cnode->inputs().size() == 3);
+    // get previous node's strategy_cost_
+    auto pre_node = cnode->input(1);
+    int32_t out_index = 0;
+    OperatorInfoPtr pre_operator_info;
+    std::vector<std::shared_ptr<StrategyWithCost>> pre_stra_costs;
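+    // A Parameter input has no upstream operator, so reshape itself supplies the candidate layouts and costs.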
+    if (pre_node->isa<Parameter>()) {
+      OperatorInfoPtr operator_info = cnode->operator_info();
+      auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
+      reshape_info->SetCostForReshapeWithParameter();
+      pre_operator_info = reshape_info;
+      pre_stra_costs = reshape_info->strategy_cost();
+    } else {
+      if (!FindPreNodeStraCosts(pre_node, &pre_operator_info, &out_index)) {
+        MS_LOG(EXCEPTION) << "FindPreNodeStraCosts for reshape failed";
+      }
+      pre_stra_costs = pre_operator_info->strategy_cost();
+    }
+    // get next node's strategy_cost_
+    int32_t in_index = 0;
+    OperatorInfoPtr next_operator_info;
+    std::vector<std::shared_ptr<StrategyWithCost>> next_stra_costs;
+    bool find_next_node = FindNextNodeStraCosts(cnode, &next_operator_info, &in_index);
+    if (!find_next_node) {
+      MS_LOG(INFO) << "FindNextNodeStraCosts for reshape failed";
+    }
+    // set input_layout and output_layout for reshape.
+    // init reshape and set cost for each input_layout and output_layout.
+    OperatorInfoPtr operator_info = cnode->operator_info();
+    auto reshape_info = std::dynamic_pointer_cast<ReshapeInfo>(operator_info);
+    reshape_info->set_pre_operator_name(pre_operator_info->name());
+    reshape_info->set_pre_operator_index(out_index);
+    if (find_next_node) {
+      next_stra_costs = next_operator_info->strategy_cost();
+      reshape_info->set_next_operator_name(next_operator_info->name());
+      reshape_info->set_next_operator_index(in_index);
+    }
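+    // Each strategy of the previous operator fixes one candidate input layout for reshape.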
+    for (auto pre_stra_cost : pre_stra_costs) {
+      std::vector<TensorInfo> pre_out_tensor_infos;
+      if (pre_node->isa<Parameter>()) {
+        pre_out_tensor_infos = pre_stra_cost->inputs_ptr;
+      } else {
+        pre_out_tensor_infos = pre_stra_cost->outputs_ptr;
+      }
+      if (pre_out_tensor_infos.size() <= IntToSize(out_index)) {
+        MS_LOG(EXCEPTION) << "out_index is out of range of the tensor_infos in setting reshape's input_layout";
+      }
+      TensorInfo pre_out_tensor_info = pre_out_tensor_infos[out_index];
+      TensorLayout pre_out_tensor_layout = pre_out_tensor_info.tensor_layout();
+      reshape_info->SetInputLayout(pre_out_tensor_layout);
+      // infer the strategy that matches pre_node's output layout.
+      Dimensions stra;
+      InferStraByTensorInfo(pre_out_tensor_info, &stra);
+      std::vector<Dimensions> stra_inputs = {stra};
+      StrategyPtr reshape_stra = std::make_shared<Strategy>(pre_stra_cost->strategy_ptr->GetInputStage(), stra_inputs);
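+      // No parallel-care successor was found: initialize reshape with only its input layout fixed.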
+      if (next_stra_costs.empty()) {
+        if (reshape_info->Init(nullptr) == FAILED) {
+          MS_LOG(EXCEPTION) << "Failure: operator reshape init failed";
+        }
+        // set cost for each input_layout and output_layout pair.
+        reshape_info->SetCostForReshape(reshape_stra);
+        continue;
+      }
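+      // Each successor strategy fixes one candidate output layout; cost every (input, output) layout pair.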
+      for (auto next_stra_cost : next_stra_costs) {
+        std::vector<TensorInfo> next_in_tensor_infos = next_stra_cost->inputs_ptr;
+        if (next_in_tensor_infos.size() <= IntToSize(in_index)) {
+          MS_LOG(EXCEPTION) << "in_index is out of range of the tensor_infos in setting reshape's output_layout";
+        }
+        TensorInfo next_in_tensor_info = next_in_tensor_infos[in_index];
+        TensorLayout next_in_tensor_layout = next_in_tensor_info.tensor_layout();
+        reshape_info->SetOutputLayout(next_in_tensor_layout);
+        if (reshape_info->Init(nullptr) == FAILED) {
+          MS_LOG(EXCEPTION) << "Failure: operator reshape init failed";
+        }
+        // set cost for each input_layout and output_layout pair.
+        reshape_info->SetCostForReshape(reshape_stra);
+      }
+    }
+  }
+}

 Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
   // There are 4 meta-steps to determine the parallelization strategy for the ANF graph.
   // Step 1: Traverse the ANF graph, and create NODEs for costgraph:
@@ -930,7 +1135,9 @@ Status ParallelStrategySearch(const std::vector<AnfNodePtr> &all_nodes, const FuncGraphPtr &root) {
       MS_LOG(EXCEPTION) << "Constructing nodes for cost graph failed.";
     }
   }
-
+  // The reshape operator needs the next node's input_layout as its output_layout,
+  // and the previous node's output_layout as its input_layout.
+  ReshapeCostCompute(all_nodes);
   // Step 2
   ConstructCostGraphEdges(all_nodes);
   MS_LOG(INFO) << "Constructing edges for cost graph succeeded. There are " << entire_costgraph->GetOperators().size()