[auto-monad] Fix backend control flow bug found by igamma test

pull/12663/head
He Wei 4 years ago
parent e1980d2386
commit c837fb25a2

@ -305,6 +305,9 @@ class AscendAutoMonadConverter {
} }
if (return_label_ != kNoLabel) { if (return_label_ != kNoLabel) {
(void)LabelGoto(return_label_); (void)LabelGoto(return_label_);
} else {
// Clear end goto if return label not set.
kernel_graph_->set_end_goto(nullptr);
} }
} }
} }

@ -274,6 +274,25 @@ std::vector<CNodePtr> KernelGraph::SortStartLabelAndEndGoto() {
continue; continue;
} }
//
// Re-order:
// u = LabelGoto(...)
// x = Mul(...)
// LabelSet(u)
// To:
// u = LabelGoto(...)
// LabelSet(u)
// x = Mul(...)
// This prevent Mul be skipped.
//
if (IsPrimitiveCNode(node, prim::kPrimLabelSet) && (re_order.back() != node->input(1))) {
auto iter = std::find(re_order.rbegin() + 1, re_order.rend(), node->input(1));
if (iter != re_order.rend()) {
re_order.insert(iter.base(), node);
continue;
}
}
re_order.push_back(node); re_order.push_back(node);
} }
if (end_goto_ != nullptr) { if (end_goto_ != nullptr) {

@ -15,6 +15,7 @@
import os import os
import tempfile import tempfile
import pytest import pytest
import scipy
import numpy as np import numpy as np
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.ops.operations as P import mindspore.ops.operations as P
@ -395,3 +396,24 @@ def test_summary():
event = summary_writer.read_event() event = summary_writer.read_event()
tags = set(value.tag for value in event.summary.value) tags = set(value.tag for value in event.summary.value)
assert tags == {'tensor', 'histogram', 'scalar', 'image'} assert tags == {'tensor', 'histogram', 'scalar', 'image'}
@pytest.mark.level0
@pytest.mark.platform_arm_ascend_training
@pytest.mark.platform_x86_ascend_training
@pytest.mark.env_onecard
def test_igamma():
class IGammaTest(nn.Cell):
def __init__(self):
super().__init__()
self.igamma = nn.IGamma()
def construct(self, x, a):
return self.igamma(a=a, x=x)
x = 4.22
a = 2.29
net = IGammaTest()
out = net(Tensor(x, mstype.float32), Tensor(a, mstype.float32))
expect = scipy.special.gammainc(a, x)
assert np.allclose(out.asnumpy(), expect, rtol=1e-5, atol=1e-5, equal_nan=True)

Loading…
Cancel
Save