fix onehot axis

pull/2530/head
hongxing 5 years ago
parent 8e20d4d84e
commit 7029bc5dd3

@ -27,10 +27,10 @@
namespace mindspore { namespace mindspore {
namespace parallel { namespace parallel {
void GenerateStrategy(std::shared_ptr<Graph> graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops, void GenerateStrategy(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list, const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
const std::vector<std::vector<std::string>> &input_tensor_names, const std::vector<std::vector<std::string>> &input_tensor_names,
const std::shared_ptr<std::vector<size_t>> index_list); const std::shared_ptr<std::vector<size_t>> &index_list);
std::vector<std::vector<int32_t>> PrepareMatMul(const std::shared_ptr<Graph> &graph, std::vector<std::vector<int32_t>> PrepareMatMul(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_graph, const size_t iter_ops); const size_t iter_graph, const size_t iter_ops);
@ -50,12 +50,12 @@ std::vector<std::vector<int32_t>> MakeDataParallelStrategy(const std::shared_ptr
std::vector<std::vector<int32_t>> PrepareStrategy(const std::shared_ptr<Graph> &graph, std::vector<std::vector<int32_t>> PrepareStrategy(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_graph, const size_t iter_ops); const size_t iter_graph, const size_t iter_ops);
void GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> graph, void GeneratePartitionedOperatorStrategy(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const std::shared_ptr<std::vector<size_t>> index_list); const std::shared_ptr<std::vector<size_t>> &index_list);
size_t FindIndexOfOperatorIncoming(const std::vector<std::vector<std::string>> &input_tensor_names, size_t FindIndexOfOperatorIncoming(const std::vector<std::vector<std::string>> &input_tensor_names,
const size_t iter_ops); const size_t iter_ops);
std::vector<int32_t> CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Graph> graph, std::vector<int32_t> CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops, const size_t iter_graph); const size_t iter_ops, const size_t iter_graph);
std::vector<int32_t> PrepareIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, std::vector<int32_t> PrepareIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
@ -63,19 +63,23 @@ std::vector<int32_t> PrepareIncomingOperatorInputStrategy(const std::vector<std:
std::vector<int32_t> GetAxisList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const int iter_ops); std::vector<int32_t> GetAxisList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const int iter_ops);
std::vector<int32_t> ModifyStrategyIfSqueezeIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops, std::vector<int32_t> ModifyStrategyIfSqueezeIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t incoming_op_index, std::vector<int32_t> s); const size_t incoming_op_index, std::vector<int32_t> s);
bool GetKeepDims(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops);
std::vector<int32_t> GetDimList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops); std::vector<int32_t> GetDimList(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops);
std::vector<int32_t> ModifyStrategyIfReduceIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops, std::vector<int32_t> ModifyStrategyIfReduceIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t incoming_op_index, std::vector<int32_t> s); const size_t incoming_op_index, std::vector<int32_t> s);
std::vector<int32_t> GetDimListFromAttrs(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops);
std::vector<int32_t> ModifyStrategyIfArgIncoming(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t incoming_op_index, std::vector<int32_t> s);
std::vector<int32_t> CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, std::vector<int32_t> CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops, const size_t incoming_op_index); const size_t iter_ops, const size_t incoming_op_index);
std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops, const size_t iter_ops,
std::vector<int32_t> basic_stra); std::vector<int32_t> basic_stra);
void GenerateEliminatedOperatorStrategyForward(std::shared_ptr<Graph> graph, void GenerateEliminatedOperatorStrategyForward(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const std::vector<std::vector<std::string>> &input_tensor_names, const std::vector<std::vector<std::string>> &input_tensor_names,
const std::shared_ptr<std::vector<size_t>> index_list, const std::shared_ptr<std::vector<size_t>> &index_list,
const std::shared_ptr<std::vector<size_t>> no_stra_op_list); const std::shared_ptr<std::vector<size_t>> &no_stra_op_list);
std::vector<int32_t> ModifyStrategyIfSqueezeOutgoing(const std::vector<std::shared_ptr<OperatorInfo>> &ops, std::vector<int32_t> ModifyStrategyIfSqueezeOutgoing(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_ops, std::vector<int32_t> s); const size_t iter_ops, std::vector<int32_t> s);
std::vector<int32_t> CopyOutgoingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, std::vector<int32_t> CopyOutgoingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
@ -83,12 +87,12 @@ std::vector<int32_t> CopyOutgoingOperatorInputStrategy(const std::vector<std::sh
const size_t iter_ops); const size_t iter_ops);
void GenerateEliminatedOperatorStrategyBackward(const std::vector<std::shared_ptr<OperatorInfo>> &ops, void GenerateEliminatedOperatorStrategyBackward(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const std::vector<std::vector<std::string>> &input_tensor_names, const std::vector<std::vector<std::string>> &input_tensor_names,
const std::shared_ptr<std::vector<size_t>> no_stra_op_list); const std::shared_ptr<std::vector<size_t>> &no_stra_op_list);
void GenerateRemainingOperatorStrategy(const std::shared_ptr<Graph> graph, void GenerateRemainingOperatorStrategy(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const std::vector<std::vector<std::string>> &input_tensor_names, const std::vector<std::vector<std::string>> &input_tensor_names,
const std::shared_ptr<std::vector<size_t>> index_list, const std::shared_ptr<std::vector<size_t>> &index_list,
const std::shared_ptr<std::vector<size_t>> no_stra_op_list); const std::shared_ptr<std::vector<size_t>> &no_stra_op_list);
} // namespace parallel } // namespace parallel
} // namespace mindspore } // namespace mindspore
#endif // PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_ #endif // PARALLEL_AUTO_PARALLEL_REC_GENERATE_STRATEGY_H_

@ -50,7 +50,8 @@ enum OperatorType {
kRecCast, kRecCast,
kRecReduce, kRecReduce,
kRecPReLU, kRecPReLU,
kRecGatherV2 kRecGatherV2,
kRecArgWithValue
}; };
enum InfoType { kApplication, kConstant }; enum InfoType { kApplication, kConstant };

@ -163,8 +163,8 @@ size_t GetIndexInInputTensorNames(const std::vector<std::vector<std::string>> &i
return SIZE_MAX; return SIZE_MAX;
} }
void Eliminate_Aux(const size_t node_index, const std::shared_ptr<Graph> graph, void Eliminate_Aux(const size_t node_index, const std::shared_ptr<Graph> &graph,
const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list) { const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list) {
std::vector<size_t> eli; std::vector<size_t> eli;
eli.push_back(node_index); eli.push_back(node_index);
for (size_t i = 0; i < (size_t)graph->nodes[node_index].node_out.size(); i++) { for (size_t i = 0; i < (size_t)graph->nodes[node_index].node_out.size(); i++) {
@ -211,18 +211,18 @@ void Eliminate_Aux(const size_t node_index, const std::shared_ptr<Graph> graph,
} }
} }
std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> graph, std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> &graph,
const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list, const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
const std::shared_ptr<std::vector<size_t>> index_list) { const std::shared_ptr<std::vector<size_t>> &index_list) {
MS_EXCEPTION_IF_NULL(graph); MS_EXCEPTION_IF_NULL(graph);
const std::set<OperatorType> type_list = { const std::set<OperatorType> elementwise_type = {
OperatorType::kRecReLU, OperatorType::kRecLog, OperatorType::kRecExp, OperatorType::kRecAdd, OperatorType::kRecReLU, OperatorType::kRecLog, OperatorType::kRecExp, OperatorType::kRecAdd,
OperatorType::kRecElmWiseOp, OperatorType::kRecBiasAdd, OperatorType::kRecSub, OperatorType::kRecMul, OperatorType::kRecElmWiseOp, OperatorType::kRecBiasAdd, OperatorType::kRecSub, OperatorType::kRecMul,
OperatorType::kRecDiv, OperatorType::kRecSqueeze, OperatorType::kRecReduce, OperatorType::kRecCast, OperatorType::kRecDiv, OperatorType::kRecSqueeze, OperatorType::kRecReduce, OperatorType::kRecCast,
OperatorType::kRecReshape, OperatorType::kRecGatherV2}; OperatorType::kRecReshape, OperatorType::kRecGatherV2, OperatorType::kRecArgWithValue};
for (size_t node_index = 0; node_index < (size_t)graph->nodes.size(); node_index++) { for (size_t node_index = 0; node_index < (size_t)graph->nodes.size(); node_index++) {
auto type = graph->nodes[node_index].apply.op_type; auto type = graph->nodes[node_index].apply.op_type;
if (type_list.find(type) != type_list.end()) { if (elementwise_type.find(type) != elementwise_type.end()) {
Eliminate_Aux(node_index, graph, eli_list); Eliminate_Aux(node_index, graph, eli_list);
} }
} }

@ -47,6 +47,8 @@ const std::map<std::string, OperatorType> DictOpType{
{REDUCE_MIN, OperatorType::kRecReduce}, {REDUCE_MIN, OperatorType::kRecReduce},
{REDUCE_MEAN, OperatorType::kRecReduce}, {REDUCE_MEAN, OperatorType::kRecReduce},
{GATHERV2, OperatorType::kRecGatherV2}, {GATHERV2, OperatorType::kRecGatherV2},
{ARGMAXWITHVALUE, OperatorType::kRecArgWithValue},
{ARGMINWITHVALUE, OperatorType::kRecArgWithValue},
{RELU, OperatorType::kRecReLU}, {RELU, OperatorType::kRecReLU},
{"ReLU6", OperatorType::kRecReLU}, {"ReLU6", OperatorType::kRecReLU},
@ -59,6 +61,7 @@ const std::map<std::string, OperatorType> DictOpType{
{PRELU, OperatorType::kRecPReLU}, {PRELU, OperatorType::kRecPReLU},
{TRANSPOSE, OperatorType::kRecElmWiseOp},
{L2_NORMALIZE, OperatorType::kRecElmWiseOp}, {L2_NORMALIZE, OperatorType::kRecElmWiseOp},
{TENSOR_ADD, OperatorType::kRecElmWiseOp}, {TENSOR_ADD, OperatorType::kRecElmWiseOp},
{SUB, OperatorType::kRecElmWiseOp}, {SUB, OperatorType::kRecElmWiseOp},
@ -123,12 +126,12 @@ void MakeEdge(const std::vector<std::vector<std::string>> &input_tensor_names, s
size_t GetIndexInInputTensorNames(const std::vector<std::vector<std::string>> &input_tensor_names, size_t GetIndexInInputTensorNames(const std::vector<std::vector<std::string>> &input_tensor_names,
const std::string &input_name); const std::string &input_name);
void Eliminate_Aux(const size_t node_index, const std::shared_ptr<Graph> graph, void Eliminate_Aux(const size_t node_index, const std::shared_ptr<Graph> &graph,
const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list); const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list);
std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> graph, std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> &graph,
const std::shared_ptr<std::vector<std::vector<size_t>>> eli_list, const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
const std::shared_ptr<std::vector<size_t>> index_list); const std::shared_ptr<std::vector<size_t>> &index_list);
} // namespace parallel } // namespace parallel
} // namespace mindspore } // namespace mindspore
#endif // PARALLEL_AUTO_PARALLEL_REC_PARSE_GRAPH_H_ #endif // PARALLEL_AUTO_PARALLEL_REC_PARSE_GRAPH_H_

Loading…
Cancel
Save