diff --git a/mindspore/ccsrc/parallel/step_auto_parallel.cc b/mindspore/ccsrc/parallel/step_auto_parallel.cc
index 50e6a1e84e..7a895a9458 100644
--- a/mindspore/ccsrc/parallel/step_auto_parallel.cc
+++ b/mindspore/ccsrc/parallel/step_auto_parallel.cc
@@ -350,8 +350,6 @@ bool IsAutoParallelCareNode(const CNodePtr &cnode) {
 }
 
 OperatorInfoPtr CreateTheOperatorInfo(const PrimitivePtr &prim, const CNodePtr &cnode) {
-  MS_EXCEPTION_IF_NULL(prim);
-  MS_EXCEPTION_IF_NULL(cnode);
   auto attrs = prim->attrs();
   std::vector<Shapes> shape_list = ExtractShape(cnode);
   if (shape_list.empty()) {
diff --git a/mindspore/ccsrc/parallel/step_parallel.cc b/mindspore/ccsrc/parallel/step_parallel.cc
index 31dc77b595..af5eb0159f 100644
--- a/mindspore/ccsrc/parallel/step_parallel.cc
+++ b/mindspore/ccsrc/parallel/step_parallel.cc
@@ -374,6 +374,7 @@ bool IsParallelCareNode(const CNodePtr& cnode) {
   if (prim == nullptr) {
     return false;
   }
+  auto attrs = prim->attrs();
   if (IsInBlackList(prim)) {
     MS_LOG(INFO) << "Parallel don't care node: " << prim->name();
     return false;
@@ -653,13 +654,6 @@ LossNodeInfo GetLossNodeInfo(const AnfNodePtr& loss_node) {
   MS_EXCEPTION_IF_NULL(pre_node);
   LossNodeInfo node_info;
 
-  // return -> cast
-  auto pre_cnode = pre_node->cast<CNodePtr>();
-  MS_EXCEPTION_IF_NULL(pre_cnode);
-  auto pre_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
-  if (pre_prim->name() == CAST && pre_cnode->operator_info() == nullptr) {
-    pre_node = pre_cnode->input(1);
-  }
 
   // return -> cast
   auto pre_cnode = pre_node->cast<CNodePtr>();
diff --git a/tests/ut/python/parallel/test_element_wise_function.py b/tests/ut/python/parallel/test_element_wise_function.py
index 2eb3a22ed2..0c65593d6a 100644
--- a/tests/ut/python/parallel/test_element_wise_function.py
+++ b/tests/ut/python/parallel/test_element_wise_function.py
@@ -268,32 +268,3 @@ def test_cast_before_mirror3():
     y = Tensor(np.ones([32, 64]), dtype=ms.float16)
     b = Tensor(np.ones([64, 64]), dtype=ms.float32)
     _executor.compile(net, x, y, b)
-
-
-def test_mul_two_cast():
-    class Net(nn.Cell):
-        def __init__(self, strategy1, strategy2, strategy3):
-            super().__init__()
-            self.mul = P.Mul().set_strategy(strategy1)
-            self.mul2 = P.Mul().set_strategy(strategy2)
-            self.cast = P.Cast().set_strategy(strategy3)
-            self.cast2 = P.Cast().set_strategy(strategy3)
-
-        def construct(self, x, y, b):
-            out = self.mul(x, y)
-            out = self.mul2(out, b)
-            out = self.cast(out, ms.int32)
-            out = self.cast2(out, ms.bool_)
-            return out
-
-    context.set_auto_parallel_context(device_num=8, global_rank=0)
-    strategy1 = ((2, 2), (2, 2))
-    strategy2 = ((8, 1), (8, 1))
-    strategy3 = ((8, 1), )
-    net = GradWrap(Net(strategy1, strategy2, strategy3))
-    context.set_auto_parallel_context(parallel_mode="semi_auto_parallel")
-
-    x = Tensor(np.ones([128, 32]), dtype=ms.float32)
-    y = Tensor(np.ones([128, 32]), dtype=ms.float32)
-    b = Tensor(np.ones([128, 32]), dtype=ms.float32)
-    _executor.compile(net, x, y, b)
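
Reviewer note: this revert drops the duplicated "return -> cast" peeling block in `GetLossNodeInfo`, leaving only single-cast handling. If the two-cast fix is reintroduced later, a loop would handle any number of consecutive casts without duplicating the block per cast. A minimal sketch, not part of this patch, using only calls already visible in the diff (`cast<CNodePtr>()`, `GetValueNode<PrimitivePtr>()`, `operator_info()`); the extra null check on `pre_prim` is an added assumption for safety:

```cpp
// Sketch only: peel a run of consecutive Cast nodes that carry no
// operator_info, rather than repeating the peeling block once per cast.
while (true) {
  auto pre_cnode = pre_node->cast<CNodePtr>();
  MS_EXCEPTION_IF_NULL(pre_cnode);
  auto pre_prim = GetValueNode<PrimitivePtr>(pre_cnode->input(0));
  // Stop at the first non-Cast node, or at a Cast that already has a
  // parallel strategy attached (pre_prim null check is an assumption).
  if (pre_prim == nullptr || pre_prim->name() != CAST || pre_cnode->operator_info() != nullptr) {
    break;
  }
  pre_node = pre_cnode->input(1);  // step past the Cast to its data input
}
```

With a loop like this, the deleted `test_mul_two_cast` case (two back-to-back `P.Cast` ops after the element-wise ops) would be covered by the same code path as the single-cast case.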