!8851 [AutoParallel] Restrict the parallelizable dimension of the Softmax operator

From: @ch-l
Reviewed-by: @stsuteng
Signed-off-by: @stsuteng
pull/8851/MERGE
mindspore-ci-bot 4 years ago committed by Gitee
commit 729801f8fa

@ -300,6 +300,38 @@ Strategys PrepareL2Normalize(const std::vector<std::shared_ptr<OperatorInfo>> &o
return strategies; return strategies;
} }
// Build the strategy of an operator whose AXIS dimension must stay unsplit.
//
// The strategy is first produced by MakeRecSearchStrategy; afterwards the split
// factor stored in strategies[0] at the operator's AXIS position is forced to 1.
//
// graph, ops, iter_graph, iter_ops: forwarded unchanged to MakeRecSearchStrategy;
//   ops[iter_ops] is the operator whose AXIS attribute is inspected.
// Returns the (possibly adjusted) strategies.
// Raises (via MS_LOG(EXCEPTION)) when the searched strategy is empty, when the
// AXIS attribute is not an Int64Imm, when the operator has no input tensor info
// while a negative axis must be resolved, or when the resolved axis does not
// index a valid dimension of strategies[0].
Strategys PrepareAxisRelatedStrategy(const std::shared_ptr<Graph> &graph,
                                     const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
                                     const size_t iter_ops) {
  Strategys strategies = MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);
  if (strategies.empty()) {
    MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": get empty Strategy.";
  }

  // Without an explicit AXIS attribute the last dimension (-1) is assumed.
  int64_t axis = -1;
  // Bind attrs() once so find() and end() are guaranteed to refer to the same container.
  const auto &attrs = ops[iter_ops]->attrs();
  auto iter = attrs.find(AXIS);
  if (iter != attrs.end()) {
    MS_EXCEPTION_IF_NULL(iter->second);
    if (!iter->second->isa<Int64Imm>()) {
      MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": The value of axis is not int64_t.";
    }
    axis = iter->second->cast<Int64ImmPtr>()->value();
  }

  // Translate a negative axis into a non-negative index using the rank of the first input.
  if (axis < 0) {
    const auto &inputs = ops[iter_ops]->inputs_tensor_info();
    if (inputs.empty()) {
      MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": inputs tensor info is empty.";
    }
    axis += SizeToLong(inputs[0].shape().size());
  }

  // An axis outside [0, rank) would index strategies[0] out of bounds (undefined behaviour).
  if (axis < 0 || axis >= SizeToLong(strategies[0].size())) {
    MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": axis value is out of range.";
  }

  if (strategies[0][axis] != 1) {
    strategies[0][axis] = 1;
    MS_LOG(INFO) << ops[iter_ops]->name() << ": adjust strategy to 1 on axis " << axis;
  }
  return strategies;
}
Strategys MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph, Strategys MakeRecSearchStrategy(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
const size_t iter_ops) { const size_t iter_ops) {
@ -437,6 +469,8 @@ Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector
return PrepareMatMul(graph, ops, iter_graph, iter_ops); return PrepareMatMul(graph, ops, iter_graph, iter_ops);
} else if (type == ONEHOT) { } else if (type == ONEHOT) {
return PrepareOneHot(graph, ops, iter_graph, iter_ops); return PrepareOneHot(graph, ops, iter_graph, iter_ops);
} else if (type == SOFTMAX) {
return PrepareAxisRelatedStrategy(graph, ops, iter_graph, iter_ops);
} else if ((type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) || (type == "_VirtualDataset") || } else if ((type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) || (type == "_VirtualDataset") ||
(type == "FusedBatchNormEx") || (type == "Dropout")) { (type == "FusedBatchNormEx") || (type == "Dropout")) {
return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops); return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);

@ -36,6 +36,9 @@ Strategys PrepareMatMul(const std::shared_ptr<Graph> &graph, const std::vector<s
Strategys PrepareBiasAdd(const std::shared_ptr<Dimensions> &s); Strategys PrepareBiasAdd(const std::shared_ptr<Dimensions> &s);
Strategys PrepareOneHot(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops, Strategys PrepareOneHot(const std::shared_ptr<Graph> &graph, const std::vector<std::shared_ptr<OperatorInfo>> &ops,
const size_t iter_graph, const size_t iter_ops); const size_t iter_graph, const size_t iter_ops);
// Builds the strategy for ops[iter_ops] through MakeRecSearchStrategy and then
// forces the split factor of the first strategy entry to 1 on the dimension named
// by the operator's AXIS attribute (-1, i.e. the last dimension, when the attribute
// is absent). PrepareStrategy dispatches SOFTMAX operators to this function.
Strategys PrepareAxisRelatedStrategy(const std::shared_ptr<Graph> &graph,
const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_graph,
const size_t iter_ops);
Strategys PrepareGatherV2(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s); Strategys PrepareGatherV2(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s);
Strategys PrepareGatherV2P(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s); Strategys PrepareGatherV2P(const std::vector<std::shared_ptr<OperatorInfo>> &ops, const size_t iter_ops, Dimensions s);
Dimensions PrepareGatherV2POutputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops, Dimensions PrepareGatherV2POutputStrategy(const std::vector<std::shared_ptr<OperatorInfo>> &ops,

@ -73,9 +73,9 @@ const std::map<std::string, OperatorType> DictOpType{
{PRELU, OperatorType::kRecPReLU}, {PRELU, OperatorType::kRecPReLU},
// Elm-wise OP // Elm-wise OP
{TRANSPOSE, OperatorType::kRecElmWiseOp}, {TRANSPOSE, OperatorType::kRecElmWiseOp},
{TRANSPOSE, OperatorType::kRecElmWiseOp},
{L2_NORMALIZE, OperatorType::kRecElmWiseOp}, {L2_NORMALIZE, OperatorType::kRecElmWiseOp},
{TENSOR_ADD, OperatorType::kRecElmWiseOp}, {TENSOR_ADD, OperatorType::kRecElmWiseOp},
{TENSOR_DOT, OperatorType::kRecElmWiseOp},
{SUB, OperatorType::kRecElmWiseOp}, {SUB, OperatorType::kRecElmWiseOp},
{MUL, OperatorType::kRecElmWiseOp}, {MUL, OperatorType::kRecElmWiseOp},
{DIV, OperatorType::kRecElmWiseOp}, {DIV, OperatorType::kRecElmWiseOp},

Loading…
Cancel
Save