support user defined strategy

pull/10209/head
sheng 4 years ago
parent 74b03da452
commit dbab352861

@ -25,6 +25,7 @@
#include "frontend/parallel/auto_parallel/rec_core/rec_partition.h"
#include "frontend/parallel/ops_info/operator_info.h"
#include "frontend/parallel/strategy.h"
#include "frontend/parallel/step_parallel.h"
namespace mindspore {
namespace parallel {
@ -43,6 +44,14 @@ void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std
GenerateEliminatedOperatorStrategyForward(graph, ops, input_tensor_names, index_list, no_stra_op_list);
GenerateEliminatedOperatorStrategyBackward(ops, input_tensor_names, no_stra_op_list);
GenerateRemainingOperatorStrategy(graph, ops, input_tensor_names, index_list, no_stra_op_list);
for (auto &op : ops) {
auto attrs = op->attrs();
if (StrategyFound(attrs)) {
StrategyPtr user_defined_stra = parallel::ExtractStrategy(attrs);
op->SetSelectedStrategyAndCost(user_defined_stra, op->selected_cost());
}
}
}
Strategys PrepareMatMul(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,

Loading…
Cancel
Save