support squeeze and reduce op

pull/1152/head
hongxing 5 years ago
parent 6b68671805
commit dc290d7959

@ -31,6 +31,11 @@ void GenerateStrategy(std::shared_ptr<Graph> graph, const std::vector<std::share
std::vector<int32_t> PrepareMatMul(const std::shared_ptr<Graph> &graph, std::vector<int32_t> PrepareMatMul(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_nodes, const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_nodes,
const size_t iter_op_inputs); const size_t iter_op_inputs);
std::vector<std::vector<int32_t>> PrepareVirtualDataset(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops);
std::vector<std::vector<int32_t>> PrepareBiasAdd(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops, std::vector<int32_t> s);
std::vector<std::vector<int32_t>> PrepareOneHot(std::vector<int32_t> s);
std::vector<int32_t> MakeRecSearchStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, std::vector<int32_t> MakeRecSearchStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const std::shared_ptr<Graph> &graph, const size_t iter_ops, const std::shared_ptr<Graph> &graph, const size_t iter_ops,
const size_t iter_op_inputs); const size_t iter_op_inputs);
@ -39,6 +44,24 @@ std::vector<int32_t> MakeDataParallelStrategy(const std::vector<std::shared_ptr<
std::vector<int32_t> PrepareStrategy(const std::shared_ptr<Graph> &graph, std::vector<int32_t> PrepareStrategy(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops,
const size_t iter_op_inputs); const size_t iter_op_inputs);
int FindIndexOfOperatorIncoming(const std::vector<std::vector<std::string>> &input_tensor_names, const size_t iter_ops);
std::vector<int32_t> CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Graph> graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops, const size_t iter_graph);
std::vector<int32_t> PrepareIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const int incoming_op_index);
std::vector<int32_t> GetAxisList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const int iter_ops);
std::vector<int32_t> ModifyStrategyIfSqueezeIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const int incoming_op_index, std::vector<int32_t> s);
std::vector<int32_t> GetDimList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops);
std::vector<int32_t> ModifyStrategyIfReduceIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const int incoming_op_index, std::vector<int32_t> s);
std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops, std::vector<int32_t> s);
std::vector<int32_t> ModifyStrategyIfSqueezeOutgoing(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops, std::vector<int32_t> s);
std::vector<int32_t> ModifyStrategyIfReduceOutgoing(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops, std::vector<int32_t> s);
} // namespace parallel } // namespace parallel
} // namespace mindspore } // namespace mindspore
#endif // PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_ #endif // PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_

@ -140,6 +140,7 @@ class OperatorInfo {
CostPtr selected_cost() const { return selected_cost_; } CostPtr selected_cost() const { return selected_cost_; }
Status InitSelectedStrategy(const StrategyPtr &s_strategy) { return Init(s_strategy); } Status InitSelectedStrategy(const StrategyPtr &s_strategy) { return Init(s_strategy); }
void set_input_value(const std::vector<ValuePtr> &input_value) { input_value_ = input_value; } void set_input_value(const std::vector<ValuePtr> &input_value) { input_value_ = input_value; }
const std::vector<ValuePtr> &input_value() const { return input_value_; }
void set_outputs_dtype(const TypePtr &dtype) { outputs_dtype_ = dtype; } void set_outputs_dtype(const TypePtr &dtype) { outputs_dtype_ = dtype; }
void set_cnode(const CNodePtr &cnode) { cnode_ = cnode; } void set_cnode(const CNodePtr &cnode) { cnode_ = cnode; }
bool is_alive() const { return is_alive_; } bool is_alive() const { return is_alive_; }

Loading…
Cancel
Save