|
|
|
@ -25,6 +25,7 @@
|
|
|
|
|
#include "frontend/parallel/auto_parallel/rec_core/rec_partition.h"
|
|
|
|
|
#include "frontend/parallel/ops_info/operator_info.h"
|
|
|
|
|
#include "frontend/parallel/strategy.h"
|
|
|
|
|
#include "frontend/parallel/step_parallel.h"
|
|
|
|
|
|
|
|
|
|
namespace mindspore {
|
|
|
|
|
namespace parallel {
|
|
|
|
@ -43,6 +44,14 @@ void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std
|
|
|
|
|
GenerateEliminatedOperatorStrategyForward(graph, ops, input_tensor_names, index_list, no_stra_op_list);
|
|
|
|
|
GenerateEliminatedOperatorStrategyBackward(ops, input_tensor_names, no_stra_op_list);
|
|
|
|
|
GenerateRemainingOperatorStrategy(graph, ops, input_tensor_names, index_list, no_stra_op_list);
|
|
|
|
|
|
|
|
|
|
for (auto &op : ops) {
|
|
|
|
|
auto attrs = op->attrs();
|
|
|
|
|
if (StrategyFound(attrs)) {
|
|
|
|
|
StrategyPtr user_defined_stra = parallel::ExtractStrategy(attrs);
|
|
|
|
|
op->SetSelectedStrategyAndCost(user_defined_stra, op->selected_cost());
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
Strategys PrepareMatMul(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
|
|
|
|
|