diff --git a/mindspore/ccsrc/parallel/step_parallel.cc b/mindspore/ccsrc/parallel/step_parallel.cc
index 7a25f44f99..fd09b5e0b5 100644
--- a/mindspore/ccsrc/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/parallel/step_parallel.cc
@@ -1683,7 +1683,10 @@ CNodePtr FindLossCNode(const FuncGraphPtr &func_graph) {
   MS_EXCEPTION_IF_NULL(pre_node);
 
   auto pre_cnode = pre_node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(pre_cnode);
+  if (pre_cnode == nullptr) {
+    return nullptr;
+  }
+
   auto current_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
   // return -> cast
   if (current_prim->name() == CAST && pre_cnode->operator_info() == nullptr) {
@@ -1907,21 +1910,6 @@ void StepSplitSens(const std::pair<CNodePtr, CNodePtr> &sens_loss_pair) {
   }
 }
 
-std::vector<CNodePtr> FindLossCNodeFromRoot(const FuncGraphPtr &root) {
-  MS_EXCEPTION_IF_NULL(root);
-  AnfNodePtr root_return_node = root->get_return();
-  MS_EXCEPTION_IF_NULL(root_return_node);
-  std::vector<CNodePtr> loss_node;
-  const auto &all_nodes = root->nodes();
-  std::set<FuncGraphPtr> graph_set = FindForwardGraphByRootNodes(all_nodes);
-  if (graph_set.empty()) {
-    loss_node.push_back(FindLossCNode(root));
-  }
-  (void)std::transform(graph_set.begin(), graph_set.end(), std::back_inserter(loss_node),
-                       [](const FuncGraphPtr &graph) { return FindLossCNode(graph); });
-  return loss_node;
-}
-
 // Sens node satisfies the following conditions: cnode(sens)-->cnode(tuple_getitem)-->cnode-->cnode(J)
 std::vector<std::pair<CNodePtr, CNodePtr>> GetSensLossPairs(const FuncGraphPtr &root) {
   MS_EXCEPTION_IF_NULL(root);
@@ -1968,6 +1956,10 @@ std::vector<std::pair<CNodePtr, CNodePtr>> GetSensLossPairs(const FuncGraphPtr &
     }
     auto func_graph = GetValueNode<FuncGraphPtr>(expect_j_cnode->input(1));
     auto loss_cnode = FindLossCNode(func_graph);
+    if (loss_cnode == nullptr) {
+      MS_LOG(WARNING) << "Can not find the loss cnode";
+      continue;
+    }
     std::pair<CNodePtr, CNodePtr> sens_loss_pair = std::make_pair(sens_cnode, loss_cnode);
     sens_loss_pairs.push_back(sens_loss_pair);
   }
@@ -2158,10 +2150,14 @@ std::set<FuncGraphPtr> ForwardGraph(const FuncGraphPtr &root) {
 
 std::vector<AnfNodePtr> FindRootForwardCNode(const FuncGraphPtr &graph, const AnfNodeSet &all_nodes) {
   MS_EXCEPTION_IF_NULL(graph);
+  std::vector<AnfNodePtr> root_forward_nodes;
   auto loss_cnode = FindLossCNode(graph);
-  MS_EXCEPTION_IF_NULL(loss_cnode);
+  if (loss_cnode == nullptr) {
+    MS_LOG(WARNING) << "Can not find the loss cnode";
+    return root_forward_nodes;
+  }
+
   auto loss_cnode_id = loss_cnode->UniqueIdThroughCopy();
-  std::vector<AnfNodePtr> root_forward_nodes;
   for (auto &node : all_nodes) {
     MS_EXCEPTION_IF_NULL(node);
     if (!node->isa<CNode>()) {
diff --git a/mindspore/ccsrc/parallel/step_parallel.h b/mindspore/ccsrc/parallel/step_parallel.h
index 93c3ed798c..308473dcd7 100644
--- a/mindspore/ccsrc/parallel/step_parallel.h
+++ b/mindspore/ccsrc/parallel/step_parallel.h
@@ -144,8 +144,6 @@ bool StepParallel(const FuncGraphPtr &func_graph, const opt::OptimizerPtr &optim
 
 int32_t GetTupleGetItemIndex(const CNodePtr &cnode);
 
-std::vector<CNodePtr> FindLossCNodeFromRoot(const FuncGraphPtr &root);
-
 Status ParallelInit();
 
 std::vector<std::string> ExtractInputsTensorName(const CNodePtr &node);
diff --git a/tests/ut/python/parallel/test_reshape_optimized.py b/tests/ut/python/parallel/test_reshape_optimized.py
new file mode 100644
index 0000000000..74b4c0024d
--- /dev/null
+++ b/tests/ut/python/parallel/test_reshape_optimized.py
@@ -0,0 +1,54 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+
+import mindspore as ms
+from mindspore import context, Tensor, Parameter
+from mindspore.common.api import _executor
+from mindspore.nn import Cell, TrainOneStepCell, Momentum
+from mindspore.ops import operations as P
+
+
+class Net(Cell):
+    def __init__(self, mul_weight):
+        super().__init__()
+        self.reshape1 = P.Reshape()
+        self.reshape2 = P.Reshape()
+        self.mul_weight = Parameter(mul_weight, "w1")
+
+    def construct(self, x, b):
+        out = self.reshape1(self.mul_weight, (128, 64, 32))
+        out = self.reshape2(out, (128, 64, 32))
+        return out
+
+
+_x = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
+_w1 = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
+_b = Tensor(np.ones([128, 64, 32]), dtype=ms.float32)
+
+
+def compile_net(net):
+    context.set_context(save_graphs=True)
+    optimizer = Momentum(net.trainable_params(), learning_rate=0.1, momentum=0.9)
+    train_net = TrainOneStepCell(net, optimizer)
+    train_net.set_auto_parallel()
+    _executor.compile(train_net, _x, _b)
+    context.reset_auto_parallel_context()
+
+
+def test_reshape_optimized():
+    context.set_auto_parallel_context(parallel_mode="auto_parallel", device_num=16, global_rank=0)
+    net = Net(_w1)
+    compile_net(net)
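
Reviewer note: the behavioral change in this patch is that FindLossCNode may now return nullptr instead of raising through MS_EXCEPTION_IF_NULL, and each caller degrades gracefully: GetSensLossPairs logs a warning and skips the sens/loss pair, while FindRootForwardCNode logs a warning and returns an empty vector. Below is a minimal standalone C++ sketch of that caller contract; every type and name in it is an illustrative stand-in, not a MindSpore API.

#include <iostream>
#include <memory>
#include <string>
#include <vector>

// Illustrative stand-in for a graph node; not a MindSpore type.
struct Node {
  std::string name;
};
using NodePtr = std::shared_ptr<Node>;

// Stand-in for FindLossCNode: the lookup may legitimately fail, so it
// returns nullptr instead of raising, mirroring the patched behavior.
NodePtr FindLoss(const std::vector<NodePtr> &graph) {
  for (const auto &node : graph) {
    if (node != nullptr && node->name == "loss") {
      return node;
    }
  }
  return nullptr;
}

// Stand-in for FindRootForwardCNode: checks the nullable result, warns,
// and returns an empty vector so compilation can continue.
std::vector<NodePtr> FindForwardNodes(const std::vector<NodePtr> &graph) {
  std::vector<NodePtr> forward_nodes;
  NodePtr loss = FindLoss(graph);
  if (loss == nullptr) {
    std::cerr << "[WARNING] Can not find the loss cnode" << std::endl;
    return forward_nodes;  // graceful degradation instead of an exception
  }
  forward_nodes.push_back(loss);
  return forward_nodes;
}

int main() {
  // A graph with no loss node (as in the new reshape-only test):
  // the lookup warns and yields an empty result instead of aborting.
  std::vector<NodePtr> graph = {std::make_shared<Node>(Node{"reshape"})};
  std::cout << "forward nodes: " << FindForwardNodes(graph).size() << std::endl;
  return 0;
}

This matches why test_reshape_optimized.py exercises the path: its Net has no loss operator, so compilation previously tripped the null check, whereas now it only emits the warning.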