|
|
|
@ -412,16 +412,11 @@ Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector
|
|
|
|
|
MS_EXCEPTION_IF_NULL(ops[iter_ops]);
|
|
|
|
|
|
|
|
|
|
auto type = ops[iter_ops]->type();
|
|
|
|
|
auto idx = DictOpType.find(type);
|
|
|
|
|
if (idx == DictOpType.end()) {
|
|
|
|
|
return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if (type == MATMUL) {
|
|
|
|
|
return PrepareMatMul(graph, ops, iter_graph, iter_ops);
|
|
|
|
|
} else if (type == ONEHOT) {
|
|
|
|
|
return PrepareOneHot(graph, ops, iter_graph, iter_ops);
|
|
|
|
|
} else if (type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) {
|
|
|
|
|
} else if ((type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) || (type == "_VirtualDataset")) {
|
|
|
|
|
return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
|
|
|
|
|
} else {
|
|
|
|
|
return MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);
|
|
|
|
|