@@ -368,14 +368,19 @@ Strategys MakeDataParallelStrategy(const std::shared_ptr<Graph> &graph,
     for (size_t dim = 0; dim < input_size; dim++) {
       if (input_size == 1 || input_size == 2 || input_size == 4) {
         if (dim == 0) {
-          s.push_back(std::min(max_device_num, target_tensor_batch));
+          // Currently GPU version does not support partitioning 'FusedBatchNormEx' in its param tensors.
+          if (ops[iter_ops]->type() == "FusedBatchNormEx" && iter_op_inputs != 0) {
+            s.push_back(1);
+          } else {
+            s.push_back(std::min(max_device_num, target_tensor_batch));
+          }
         } else {
           s.push_back(1);
         }
       } else if (input_size == 0) {
         s = {};
       } else {
-        MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor's shape is unknown.";
+        MS_LOG(EXCEPTION) << ops[iter_ops]->name() << ": Tensor shape " << input_size << " is unexpected.";
       }
     }
     strategies.push_back(s);
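The hunk keeps the data-parallel rule of splitting only the batch dimension, but carves out the param tensors of FusedBatchNormEx (input index other than 0), which the GPU kernel cannot partition. Below is a minimal standalone sketch of that rule; the function name DataParallelDims and its parameter list are invented for illustration and are not the MindSpore API.

// Minimal standalone sketch of the splitting rule in the hunk above. Names
// (DataParallelDims and its parameters) are made up for illustration; this is
// not the MindSpore implementation.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

std::vector<int64_t> DataParallelDims(const std::string &op_type, size_t input_index,
                                      size_t input_size, int64_t max_device_num,
                                      int64_t target_tensor_batch) {
  std::vector<int64_t> s;
  for (size_t dim = 0; dim < input_size; dim++) {
    if (dim == 0) {
      // Mirror of the new branch: FusedBatchNormEx param tensors (input index != 0)
      // are left unsplit because the GPU kernel cannot partition them.
      if (op_type == "FusedBatchNormEx" && input_index != 0) {
        s.push_back(1);
      } else {
        s.push_back(std::min(max_device_num, target_tensor_batch));
      }
    } else {
      s.push_back(1);  // non-batch dimensions are never split in the data-parallel strategy
    }
  }
  return s;
}

int main() {
  // 8 devices, batch 32: the 4-D data input is split 8-way along dim 0 ...
  for (int64_t d : DataParallelDims("FusedBatchNormEx", 0, 4, 8, 32)) std::cout << d << ' ';
  std::cout << std::endl;  // 8 1 1 1
  // ... while a 1-D scale/offset param tensor stays whole.
  for (int64_t d : DataParallelDims("FusedBatchNormEx", 1, 1, 8, 32)) std::cout << d << ' ';
  std::cout << std::endl;  // 1
  return 0;
}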
@@ -416,6 +421,8 @@ Strategys PrepareStrategy(const std::shared_ptr<Graph> &graph, const std::vector
     return PrepareMatMul(graph, ops, iter_graph, iter_ops);
   } else if (type == ONEHOT) {
     return PrepareOneHot(graph, ops, iter_graph, iter_ops);
+  } else if (type == SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS) {
+    return MakeDataParallelStrategy(graph, ops, iter_graph, iter_ops);
   } else {
     return MakeRecSearchStrategy(graph, ops, iter_graph, iter_ops);
   }
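The second hunk only extends the operator-type dispatch in PrepareStrategy so that the sparse softmax cross-entropy loss op falls back to MakeDataParallelStrategy instead of the recursive strategy search. A self-contained sketch of that dispatch follows; the names are invented, and the literal op-name string is an assumption about what the SPARSE_SOFTMAX_CROSS_ENTROPY_WITH_LOGITS constant expands to.

// Standalone sketch (assumed names, not the MindSpore API) of the dispatch this
// hunk extends: the operator type picks the strategy generator, and the new
// branch sends the loss op to the data-parallel path.
#include <iostream>
#include <string>

enum class Generator { kMatMul, kOneHot, kDataParallel, kRecSearch };

Generator SelectGenerator(const std::string &type) {
  if (type == "MatMul") {
    return Generator::kMatMul;        // PrepareMatMul
  } else if (type == "OneHot") {
    return Generator::kOneHot;        // PrepareOneHot
  } else if (type == "SparseSoftmaxCrossEntropyWithLogits") {
    return Generator::kDataParallel;  // new branch: MakeDataParallelStrategy
  } else {
    return Generator::kRecSearch;     // default: MakeRecSearchStrategy
  }
}

int main() {
  // The loss op now takes the data-parallel path rather than the recursive search.
  std::cout << (SelectGenerator("SparseSoftmaxCrossEntropyWithLogits") == Generator::kDataParallel)
            << std::endl;  // prints 1
  return 0;
}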