|
|
@ -20,7 +20,6 @@
|
|
|
|
#include <memory>
|
|
|
|
#include <memory>
|
|
|
|
#include <string>
|
|
|
|
#include <string>
|
|
|
|
#include <vector>
|
|
|
|
#include <vector>
|
|
|
|
#include <set>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#include "ir/value.h"
|
|
|
|
#include "ir/value.h"
|
|
|
|
#include "parallel/auto_parallel/rec_core/rec_graph.h"
|
|
|
|
#include "parallel/auto_parallel/rec_core/rec_graph.h"
|
|
|
@ -215,23 +214,16 @@ std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> &graph,
|
|
|
|
const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
|
|
|
|
const std::shared_ptr<std::vector<std::vector<size_t>>> &eli_list,
|
|
|
|
const std::shared_ptr<std::vector<size_t>> &index_list) {
|
|
|
|
const std::shared_ptr<std::vector<size_t>> &index_list) {
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
MS_EXCEPTION_IF_NULL(graph);
|
|
|
|
static const std::set<OperatorType> elementwise_type = {
|
|
|
|
|
|
|
|
OperatorType::kRecReLU, OperatorType::kRecLog, OperatorType::kRecExp, OperatorType::kRecAdd,
|
|
|
|
|
|
|
|
OperatorType::kRecElmWiseOp, OperatorType::kRecBiasAdd, OperatorType::kRecSub, OperatorType::kRecMul,
|
|
|
|
|
|
|
|
OperatorType::kRecDiv, OperatorType::kRecSqueeze, OperatorType::kRecReduce, OperatorType::kRecCast,
|
|
|
|
|
|
|
|
OperatorType::kRecReshape, OperatorType::kRecGatherV2, OperatorType::kRecArgWithValue};
|
|
|
|
|
|
|
|
for (size_t node_index = 0; node_index < (size_t)graph->nodes.size(); node_index++) {
|
|
|
|
for (size_t node_index = 0; node_index < (size_t)graph->nodes.size(); node_index++) {
|
|
|
|
auto type = graph->nodes[node_index].apply.op_type;
|
|
|
|
auto type = graph->nodes[node_index].apply.op_type;
|
|
|
|
if (elementwise_type.find(type) != elementwise_type.end()) {
|
|
|
|
if (ElementWiseOpType.find(type) != ElementWiseOpType.end()) {
|
|
|
|
Eliminate_Aux(node_index, graph, eli_list);
|
|
|
|
Eliminate_Aux(node_index, graph, eli_list);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
index_list->reserve(graph->nodes.size());
|
|
|
|
index_list->reserve(graph->nodes.size());
|
|
|
|
for (size_t i = 0; i < (size_t)graph->nodes.size(); i++) {
|
|
|
|
for (size_t i = 0; i < (size_t)graph->nodes.size(); i++) {
|
|
|
|
index_list->push_back(i);
|
|
|
|
index_list->push_back(i);
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
for (size_t i = 0; i < (size_t)eli_list->size(); i++) {
|
|
|
|
for (size_t i = 0; i < (size_t)eli_list->size(); i++) {
|
|
|
|
if (eli_list->at(i)[0] >= index_list->size()) {
|
|
|
|
if (eli_list->at(i)[0] >= index_list->size()) {
|
|
|
|
MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
|
|
|
|
MS_LOG(EXCEPTION) << "Failure: Operators' elements out of range.";
|
|
|
@ -241,13 +233,11 @@ std::shared_ptr<Graph> EliminateGraph(const std::shared_ptr<Graph> &graph,
|
|
|
|
index_list->at(j)--;
|
|
|
|
index_list->at(j)--;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
std::shared_ptr<Graph> new_graph(new Graph);
|
|
|
|
std::shared_ptr<Graph> new_graph(new Graph);
|
|
|
|
for (size_t i = 0; i < graph->nodes.size(); i++) {
|
|
|
|
for (size_t i = 0; i < graph->nodes.size(); i++) {
|
|
|
|
if (index_list->at(i) > SIZE_MAX / 2) {
|
|
|
|
if (index_list->at(i) > SIZE_MAX / 2) {
|
|
|
|
continue;
|
|
|
|
continue;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
new_graph->nodes.push_back(graph->nodes[i]);
|
|
|
|
new_graph->nodes.push_back(graph->nodes[i]);
|
|
|
|
auto *node_in = &new_graph->nodes[index_list->at(i)].node_in;
|
|
|
|
auto *node_in = &new_graph->nodes[index_list->at(i)].node_in;
|
|
|
|
for (size_t j = node_in->size(); j > 0; j--) {
|
|
|
|
for (size_t j = node_in->size(); j > 0; j--) {
|
|
|
|