diff --git a/mindspore/ccsrc/backend/session/kernel_graph.cc b/mindspore/ccsrc/backend/session/kernel_graph.cc
index c05ebc33b2..9d7787d0ef 100644
--- a/mindspore/ccsrc/backend/session/kernel_graph.cc
+++ b/mindspore/ccsrc/backend/session/kernel_graph.cc
@@ -201,21 +201,17 @@ void KernelGraph::VisitNodeDescendants(const AnfNodePtr &node, std::queue<AnfNodePtr> *visit_queue,
 }
 
 void KernelGraph::SetExecOrderByDefault() {
-  std::queue<AnfNodePtr> seed_nodes;
-  UpdateNodeEdgeList(&seed_nodes);
+  std::queue<AnfNodePtr> zero_input_nodes;
+  UpdateNodeEdgeList(&zero_input_nodes);
   execution_order_.clear();
   std::unordered_set<AnfNodePtr> visited_nodes;
-  std::queue<AnfNodePtr> zero_input_nodes;
   AnfNodePtr last_communication_node = nullptr;
   std::queue<AnfNodePtr> communication_descendants;
-  while (!seed_nodes.empty() || last_communication_node != nullptr) {
+  while (!zero_input_nodes.empty() || last_communication_node != nullptr) {
     // seed nodes first, then visit last all reduce node descendant
-    if (seed_nodes.empty()) {
+    if (last_communication_node != nullptr) {
       VisitNodeDescendants(last_communication_node, &communication_descendants, &visited_nodes);
       last_communication_node = nullptr;
-    } else {
-      zero_input_nodes.push(seed_nodes.front());
-      seed_nodes.pop();
     }
     // all reduce node descendant first, then common queue
     while (!zero_input_nodes.empty() || !communication_descendants.empty()) {
@@ -901,14 +897,11 @@ void KernelGraph::UpdateNodeEdgeList(std::queue<AnfNodePtr> *seed_nodes) {
       seed_nodes->push(node);
       continue;
     }
-    auto cnode = dyn_cast<CNode>(node);
+    auto cnode = node->cast<CNodePtr>();
     if (cnode == nullptr) {
      continue;
     }
-    auto &inputs = cnode->inputs();
-    // We push inputs from right to left, so that them can be evaluated from left to right.
-    for (auto iter = inputs.rbegin(); iter != inputs.rend(); ++iter) {
-      auto &input = *iter;
+    for (auto &input : cnode->inputs()) {
       PushNoVisitedNode(input, &que, &visited_nodes);
       AddDependEdge(node, input, 1);
     }
diff --git a/tests/st/auto_monad/test_auto_monad.py b/tests/st/auto_monad/test_auto_monad.py
index 11d1d1efd1..54b223401d 100644
--- a/tests/st/auto_monad/test_auto_monad.py
+++ b/tests/st/auto_monad/test_auto_monad.py
@@ -1427,3 +1427,43 @@ def test_if_cast():
     r1 = net(beta1, beta2)
     expect = Tensor(np.array([3]).astype(np.float32))
     np.testing.assert_array_equal(r1.asnumpy(), expect.asnumpy())
+
+
+@pytest.mark.level0
+@pytest.mark.platform_arm_ascend_training
+@pytest.mark.platform_x86_ascend_training
+@pytest.mark.env_onecard
+def test_multi_add_assign():
+    class Net(Cell):
+        def __init__(self, i1):
+            super(Net, self).__init__()
+            self.add = P.Add()
+            self.sub = P.Sub()
+            self.mul = P.Mul()
+            self.assign = P.Assign()
+            self.p = Parameter(i1, name='para')
+
+        def construct(self, a, d, e):
+            res1 = self.add(self.add(self.add(self.p, a), a), a)
+            mul = self.mul(d, e)
+            self.assign(self.p, mul)
+            res2 = self.sub(self.p, e)
+            return res2, res1
+
+    def numpy_out(p, a, d, e):
+        res1 = p + a + a + a
+        res_as = d * e
+        res2 = d * e - e
+        return res2, res1, res_as
+
+    p = (np.abs(np.random.normal(0, 1, [3])) + 1).astype(np.float32)
+    i0 = (np.abs(np.random.normal(0, 1, [3])) + 1).astype(np.float32)
+    i1 = (np.abs(np.random.normal(0, 1, [3])) + 1).astype(np.float32)
+    i2 = (np.abs(np.random.normal(0, 1, [3])) + 1).astype(np.float32)
+
+    net = Net(Tensor(p))
+    r2, r1 = net(Tensor(i0), Tensor(i1), Tensor(i2))
+
+    outputs = [r2.asnumpy(), r1.asnumpy(), net.p.data.asnumpy()]
+    expects = numpy_out(p, i0, i1, i2)
+    np.testing.assert_array_equal(outputs, expects)
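Reviewer note: the kernel_graph.cc hunks change two things at once. UpdateNodeEdgeList now pushes each CNode's inputs in declaration order (the right-to-left reversal is dropped), and SetExecOrderByDefault consumes the zero-input seed nodes directly instead of copying them through a second queue. The resulting schedule is Kahn's algorithm with a FIFO ready queue. Below is a minimal Python sketch of that ordering, shaped like test_multi_add_assign above. All names are hypothetical, not the MindSpore API, and the extra "add3" input on assign stands in for the depend edge that the auto-monad pass inserts; plain dataflow alone would let assign overtake the add chain.

from collections import deque

def exec_order(inputs_of):
    """Kahn's algorithm with a FIFO ready queue.

    inputs_of maps node -> list of input nodes in declaration order,
    mirroring the patched left-to-right input traversal.
    """
    # Count how many inputs each node still waits on, and record users.
    pending = {n: len(ins) for n, ins in inputs_of.items()}
    users = {}
    for n, ins in inputs_of.items():
        for i in ins:
            pending.setdefault(i, 0)
            users.setdefault(i, []).append(n)
    # Zero-input nodes seed the queue directly, as in the new code.
    ready = deque(n for n, c in pending.items() if c == 0)
    order = []
    while ready:
        n = ready.popleft()
        order.append(n)
        for u in users.get(n, []):
            pending[u] -= 1
            if pending[u] == 0:
                ready.append(u)
    return order

# Toy graph shaped like test_multi_add_assign: res1 must read para before
# the assign, res2 after it. The "add3" input on assign models the
# auto-monad depend edge that enforces this.
graph = {
    "add1": ["para", "a"],
    "add2": ["add1", "a"],
    "add3": ["add2", "a"],              # res1: reads para before the assign
    "mul": ["d", "e"],
    "assign": ["para", "mul", "add3"],
    "sub": ["assign", "e"],             # res2: reads para after the assign
}
# -> para, a, d, e, add1, mul, add2, add3, assign, sub
print(exec_order(graph))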
diff --git a/tests/st/networks/models/bert/bert_performance/test_bert_tdt_lossscale.py b/tests/st/networks/models/bert/bert_performance/test_bert_tdt_lossscale.py
index d299c0dc8f..fb4faed58c 100644
--- a/tests/st/networks/models/bert/bert_performance/test_bert_tdt_lossscale.py
+++ b/tests/st/networks/models/bert/bert_performance/test_bert_tdt_lossscale.py
@@ -229,7 +229,7 @@ def test_bert_performance():
 
     # assertion occurs while the loss value, overflow state or loss_scale value is wrong
     loss_value = np.array(callback.loss_list)
-    expect_loss_value = [11.3660, 11.3265, 11.3264]
+    expect_loss_value = [11.3246, 11.2834, 11.2833]
     print("loss value: {}".format(loss_value))
     assert np.allclose(loss_value, expect_loss_value, 0, 0.0005)
diff --git a/tests/st/networks/models/bert/bert_precision/test_bert_tdt_lossscale.py b/tests/st/networks/models/bert/bert_precision/test_bert_tdt_lossscale.py
index 53e49b5882..017821fdac 100644
--- a/tests/st/networks/models/bert/bert_precision/test_bert_tdt_lossscale.py
+++ b/tests/st/networks/models/bert/bert_precision/test_bert_tdt_lossscale.py
@@ -229,8 +229,8 @@ def test_bert_precision(enable_graph_kernel=False):
         expect_loss_value = [12.206627, 11.840489, 11.798470, 11.796345, 11.790964, 12.366766, 11.971539, 12.576565,
                              12.185522, 12.386192]
     else:
-        expect_loss_value = [12.206587, 11.966410, 11.965916, 11.975922, 11.970262, 12.608881, 12.174048, 12.840656,
-                             12.407923, 12.631133]
+        expect_loss_value = [12.206587, 11.940709, 11.930911, 11.937369, 11.932178, 12.556069, 12.130172, 12.783402,
+                             12.359581, 12.578078]
     print("loss value: {}".format(loss_value))
     assert np.allclose(loss_value, expect_loss_value, 0, 0.0005)
 
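Reviewer note: both loss assertions call np.allclose(loss_value, expect_loss_value, 0, 0.0005), i.e. positional rtol=0 and atol=0.0005, so the refreshed baselines are compared with a pure absolute tolerance of 5e-4 per step. A quick illustration of that semantics:

import numpy as np

# np.allclose(a, b, rtol, atol) tests |a - b| <= atol + rtol * |b|;
# with rtol=0 only the absolute difference matters.
assert np.allclose([11.3246], [11.32501], 0, 0.0005)       # diff ~4.1e-4: passes
assert not np.allclose([11.3246], [11.3260], 0, 0.0005)    # diff ~1.4e-3: fails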