|
|
|
@ -135,24 +135,51 @@ std::vector<std::vector<int32_t>> PreparePReLU(const std::shared_ptr<Graph> &gra
|
|
|
|
|
return strategies;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<std::vector<int32_t>> PrepareBiasAdd(std::vector<int32_t> s) {
|
|
|
|
|
std::vector<std::vector<int32_t>> PrepareBatchNorm(const std::shared_ptr<Graph> &graph,
|
|
|
|
|
const std::vector<std::shared_ptr<OperatorInfo>> &ops,
|
|
|
|
|
const size_t iter_graph, const size_t iter_ops) {
|
|
|
|
|
std::vector<std::vector<int32_t>> strategies = MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
|
|
|
|
|
for (size_t i = 1; i < strategies.size(); i++) {
|
|
|
|
|
strategies[i][0] = strategies[0][1];
|
|
|
|
|
}
|
|
|
|
|
strategies[1][0] = 1;
|
|
|
|
|
return strategies;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
std::vector<std::vector<int32_t>> PrepareSoftmaxWithLogits(const std::shared_ptr<Graph> &graph,
|
|
|
|
|
const std::vector<std::shared_ptr<OperatorInfo>> &ops,
|
|
|
|
|
const size_t iter_graph, const size_t iter_ops) {
|
|
|
|
|
std::vector<std::vector<int32_t>> strategies = MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
|
|
|
|
|
graph->nodes[iter_graph].tensor_parm.tensor_str.str_w = graph->nodes[iter_graph].tensor_parm.tensor_str.str_h;
|
|
|
|
|
graph->nodes[iter_graph].tensor_parm.tensor_str.str_h = graph->nodes[iter_graph].tensor_parm.tensor_str.str_c;
|
|
|
|
|
graph->nodes[iter_graph].tensor_parm.tensor_str.str_c = graph->nodes[iter_graph].tensor_parm.tensor_str.str_n;
|
|
|
|
|
return strategies;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Builds the strategy list for a BiasAdd operator from a single base strategy.
//
// Returns two entries:
//   [0] a copy of the vector `s` points to;
//   [1] a one-element vector holding element 1 of that vector.
//
// `s` must be non-null. std::vector::at() throws std::out_of_range when the
// pointed-to vector has fewer than two elements.
std::vector<std::vector<int32_t>> PrepareBiasAdd(const std::shared_ptr<std::vector<int32_t>> &s) {
  std::vector<std::vector<int32_t>> strategies;
  strategies.push_back(*s);

  // The second entry carries only element 1 of the base strategy.
  std::vector<int32_t> s_biasadd;
  s_biasadd.push_back(s->at(1));
  strategies.push_back(s_biasadd);

  return strategies;
}
|
|
|
|
|
|
|
|
|
|
// Builds the strategy list for a OneHot operator from a single base strategy.
//
// Returns three entries: a copy of the vector `s` points to, followed by two
// empty strategies.
//
// `s` must be non-null.
std::vector<std::vector<int32_t>> PrepareOneHot(const std::shared_ptr<std::vector<int32_t>> &s) {
  std::vector<std::vector<int32_t>> strategies;
  std::vector<int32_t> s_empty = {};
  strategies.push_back(*s);
  strategies.push_back(s_empty);
  strategies.push_back(s_empty);
  return strategies;
}
|
|
|
|
|
|
|
|
|
|
// Builds the strategy list for a GatherV2 operator: a single entry that is a
// copy of the vector `s` points to.
//
// `s` must be non-null.
std::vector<std::vector<int32_t>> PrepareGatherV2(const std::shared_ptr<std::vector<int32_t>> &s) {
  // Count-and-value constructor: exactly one copy of the base strategy.
  std::vector<std::vector<int32_t>> strategies(1, *s);
  return strategies;
}
|
|
|
|
|
|
|
|
|
|
std::vector<std::vector<int32_t>> MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
|
|
|
|
|
const std::vector<std::shared_ptr<OperatorInfo>> &ops,
|
|
|
|
|
const size_t iter_graph, const size_t iter_ops) {
|
|
|
|
@ -270,6 +297,12 @@ std::vector<std::vector<int32_t>> PrepareStrategy(const std::shared_ptr<Graph> &
|
|
|
|
|
return PrepareMatMul(graph, ops, iter_graph, iter_ops);
|
|
|
|
|
} else if (type == PRELU) {
|
|
|
|
|
return PreparePReLU(graph, ops, iter_graph, iter_ops);
|
|
|
|
|
} else if (type == BATCH_NORM) {
|
|
|
|
|
return PrepareBatchNorm(graph, ops, iter_graph, iter_ops);
|
|
|
|
|
} else if (type == SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) {
|
|
|
|
|
return PrepareSoftmaxWithLogits(graph, ops, iter_graph, iter_ops);
|
|
|
|
|
} else if (type == SOFTMAX || type == LOG_SOFTMAX || type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) {
|
|
|
|
|
return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
|
|
|
|
|
} else {
|
|
|
|
|
return MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);
|
|
|
|
|
}
|
|
|
|
@ -336,7 +369,7 @@ std::vector<int32_t> CopyIncomingOperatorOutputStrategy(const std::shared_ptr<Gr
|
|
|
|
|
std::vector<int32_t> PrepareIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
|
|
|
|
|
const size_t incoming_op_index) {
|
|
|
|
|
std::vector<int32_t> s;
|
|
|
|
|
if (ops[incoming_op_index]->type() == RESHAPE) {
|
|
|
|
|
if (ops[incoming_op_index]->type() == RESHAPE || ops[incoming_op_index]->type() == GATHERV2) {
|
|
|
|
|
return s;
|
|
|
|
|
}
|
|
|
|
|
auto strategy = ops[incoming_op_index]->selected_strategy();
|
|
|
|
@ -456,11 +489,6 @@ std::vector<int32_t> ModifyStrategyIfReduceIncoming(const std::vector<std::share
|
|
|
|
|
return s_Reduce;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// Returns the given strategy with its last element removed.
//
// `s` is taken by value, so the caller's vector is never modified. An empty
// strategy is returned unchanged.
std::vector<int32_t> ModifyStrategyIfSoftmaxIncoming(std::vector<int32_t> s) {
  // pop_back() on an empty vector is undefined behaviour, so guard it.
  if (!s.empty()) {
    s.pop_back();
  }
  return s;
}
|
|
|
|
|
|
|
|
|
|
std::vector<int32_t> CopyIncomingOperatorInputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,
|
|
|
|
|
const size_t iter_ops, const size_t incoming_op_index) {
|
|
|
|
|
std::vector<int32_t> s;
|
|
|
|
@ -474,9 +502,6 @@ std::vector<int32_t> CopyIncomingOperatorInputStrategy(const std::vector<std::sh
|
|
|
|
|
ops[incoming_op_index]->type() == REDUCE_MIN || ops[incoming_op_index]->type() == REDUCE_MEAN) {
|
|
|
|
|
s = ModifyStrategyIfReduceIncoming(ops, incoming_op_index, s);
|
|
|
|
|
}
|
|
|
|
|
if (ops[incoming_op_index]->type() == SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) {
|
|
|
|
|
s = ModifyStrategyIfSoftmaxIncoming(s);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return s;
|
|
|
|
|
}
|
|
|
|
@ -496,11 +521,15 @@ std::vector<std::vector<int32_t>> GenerateStrategiesFromStrategy(const std::vect
|
|
|
|
|
return stra;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
auto s_ptr = std::make_shared<std::vector<int32_t>>(basic_stra);
|
|
|
|
|
if (ops[iter_ops]->type() == BIAS_ADD) {
|
|
|
|
|
return PrepareBiasAdd(basic_stra);
|
|
|
|
|
return PrepareBiasAdd(s_ptr);
|
|
|
|
|
}
|
|
|
|
|
if (ops[iter_ops]->type() == ONEHOT) {
|
|
|
|
|
return PrepareOneHot(basic_stra);
|
|
|
|
|
return PrepareOneHot(s_ptr);
|
|
|
|
|
}
|
|
|
|
|
if (ops[iter_ops]->type() == GATHERV2) {
|
|
|
|
|
return PrepareGatherV2(s_ptr);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
for (size_t iter_op_inputs = 0; iter_op_inputs < (size_t)ops[iter_ops]->inputs_tensor_info().size();
|
|
|
|
@ -599,7 +628,8 @@ std::vector<int32_t> CopyOutgoingOperatorInputStrategy(const std::vector<std::sh
|
|
|
|
|
const size_t iter_ops) {
|
|
|
|
|
std::vector<int32_t> s;
|
|
|
|
|
if (ops[iter_ops]->type() == REDUCE_MAX || ops[iter_ops]->type() == REDUCE_MIN ||
|
|
|
|
|
ops[iter_ops]->type() == REDUCE_SUM || ops[iter_ops]->type() == REDUCE_MEAN || ops[iter_ops]->type() == RESHAPE) {
|
|
|
|
|
ops[iter_ops]->type() == REDUCE_SUM || ops[iter_ops]->type() == REDUCE_MEAN || ops[iter_ops]->type() == RESHAPE ||
|
|
|
|
|
ops[iter_ops]->type() == GATHERV2) {
|
|
|
|
|
return s;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|