|
|
@ -300,6 +300,38 @@ Strategys PrepareL2Normalize(const std::vector<std::shared_ptr<OperatorInfo>> &o
|
|
|
|
return strategies;
|
|
|
|
return strategies;
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
Strategys PrepareAxisRelatedStrategy(const std::shared_ptr<Graph> &graph,
|
|
|
|
|
|
|
|
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
|
|
|
|
|
|
|
|
const size_t iter_ops) {
|
|
|
|
|
|
|
|
Strategys strategies = MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);
|
|
|
|
|
|
|
|
if (strategies.size() < 1) {
|
|
|
|
|
|
|
|
MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": get empty Strategy.";
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
int64_t axis = -1;
|
|
|
|
|
|
|
|
auto iter = ops[iter_ops]->attrs().find(AXIS);
|
|
|
|
|
|
|
|
if (iter != ops[iter_ops]->attrs().end()) {
|
|
|
|
|
|
|
|
MS_EXCEPTION_IF_NULL(iter->second);
|
|
|
|
|
|
|
|
if (iter->second->isa<Int64Imm>()) {
|
|
|
|
|
|
|
|
axis = iter->second->cast<Int64ImmPtr>()->value();
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": The value of axis is not int64_t.";
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (axis < 0) {
|
|
|
|
|
|
|
|
int64_t input_dim = SizeToLong(ops[iter_ops]->inputs_tensor_info()[0].shape().size());
|
|
|
|
|
|
|
|
axis = input_dim + axis;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (strategies[0][axis] != 1) {
|
|
|
|
|
|
|
|
strategies[0][axis] = 1;
|
|
|
|
|
|
|
|
MS_LOG(INFO) << ops[iter_ops]->name() << ": adjust strategy to 1 on axis " << axis;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return strategies;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
Strategys MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
|
|
|
|
Strategys MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
|
|
|
|
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
|
|
|
|
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
|
|
|
|
const size_t iter_ops) {
|
|
|
|
const size_t iter_ops) {
|
|
|
@ -437,6 +469,8 @@ Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector
|
|
|
|
return PrepareMatMul(graph, ops, iter_graph, iter_ops);
|
|
|
|
return PrepareMatMul(graph, ops, iter_graph, iter_ops);
|
|
|
|
} else if (type == ONEHOT) {
|
|
|
|
} else if (type == ONEHOT) {
|
|
|
|
return PrepareOneHot(graph, ops, iter_graph, iter_ops);
|
|
|
|
return PrepareOneHot(graph, ops, iter_graph, iter_ops);
|
|
|
|
|
|
|
|
} else if (type == SOFTMAX) {
|
|
|
|
|
|
|
|
return PrepareAxisRelatedStrategy(graph, ops, iter_graph, iter_ops);
|
|
|
|
} else if ((type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) || (type == "_VirtualDataset") ||
|
|
|
|
} else if ((type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) || (type == "_VirtualDataset") ||
|
|
|
|
(type == "FusedBatchNormEx") || (type == "Dropout")) {
|
|
|
|
(type == "FusedBatchNormEx") || (type == "Dropout")) {
|
|
|
|
return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
|
|
|
|
return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
|
|
|
|