From 6ccc4379b4b5fde43a6efa7f1641577ce0cd73fc Mon Sep 17 00:00:00 2001 From: buxue Date: Fri, 29 Jan 2021 18:18:14 +0800 Subject: [PATCH] do not broaden scalar --- .../jit/static_analysis/program_specialize.cc | 6 ++--- mindspore/core/abstract/abstract_value.cc | 2 +- mindspore/core/abstract/prim_others.cc | 9 +++++--- tests/ut/cpp/abstract/utils_test.cc | 12 ++-------- tests/ut/cpp/optimizer/lib_test.cc | 7 ++++-- .../cpp/pipeline/static_analysis/data_test.cc | 4 ++-- tests/ut/python/ops/test_control_ops.py | 21 ------------------ tests/ut/python/ops/test_tensor_slice.py | 22 +++++++++++++++++++ 8 files changed, 41 insertions(+), 42 deletions(-) diff --git a/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc index 60760d1bae..d8c175c74f 100644 --- a/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc +++ b/mindspore/ccsrc/pipeline/jit/static_analysis/program_specialize.cc @@ -544,7 +544,7 @@ void FuncGraphSpecializer::ProcessCNode(const CNodePtr &new_node) { } if (CanSpecializeNode(func)) { - // for primitive node , we build the primitive node with infered attributes in the first pass + // for primitive node , we build the primitive node with inferred attributes in the first pass // so we do not build replaced node again here in second pass if (IsValueNode(func)) { new_inputs[0] = func; @@ -666,14 +666,14 @@ AnfNodePtr FuncGraphSpecializer::BuildPossibleValueNode(const AnfNodePtr &origin AbstractFunctionPtr abs = dyn_cast(ival); if (abs != nullptr) { - // Cannot build a determinstic ValueNode if there are multiple possible AbstractFunction. + // Cannot build a deterministic ValueNode if there are multiple possible AbstractFunction. if (abs->isa()) { return nullptr; } ValuePtr value = nullptr; if (abs->isa()) { auto real_fn = dyn_cast(abs); - // for primitive, check if the attribute is the same with cnode infererd attribute ,if not, clone a new one + // for primitive, check if the attribute is the same with cnode inferred attribute, if not, clone a new one if (attrs != nullptr) { value = BuildPrimtiveValueWithAttributes(real_fn->prim(), attrs); } else { diff --git a/mindspore/core/abstract/abstract_value.cc b/mindspore/core/abstract/abstract_value.cc index 896ac2a78b..ca9a37b570 100644 --- a/mindspore/core/abstract/abstract_value.cc +++ b/mindspore/core/abstract/abstract_value.cc @@ -88,7 +88,7 @@ std::string AbstractBase::ToString() const { return buffer.str(); } -AbstractBasePtr AbstractScalar::Broaden(uint8_t config) const { return AbstractBase::Broaden(config); } +AbstractBasePtr AbstractScalar::Broaden(uint8_t config) const { return Clone(); } AbstractBasePtr AbstractScalar::Join(const AbstractBasePtr &other) { MS_EXCEPTION_IF_NULL(other); diff --git a/mindspore/core/abstract/prim_others.cc b/mindspore/core/abstract/prim_others.cc index a201005889..6adedb7afd 100644 --- a/mindspore/core/abstract/prim_others.cc +++ b/mindspore/core/abstract/prim_others.cc @@ -102,7 +102,7 @@ AbstractBasePtr InferImplMakeRefKey(const AnalysisEnginePtr &, const PrimitivePt ValuePtr name_value = prim->GetAttr("tag"); auto name = name_value->cast(); if (name == nullptr) { - MS_LOG(EXCEPTION) << "MakeRefKey attr tag sould be a String " << name_value->ToString() << "."; + MS_LOG(EXCEPTION) << "MakeRefKey attr tag should be a String " << name_value->ToString() << "."; } auto refkey = std::make_shared(name->value()); if (refkey == nullptr) { @@ -168,6 +168,9 @@ AbstractBasePtr InferImplDepend(const AnalysisEnginePtr &, const PrimitivePtr &p MS_LOG(EXCEPTION) << primitive->name() << " input args size should be at lest 1, but got 0"; } auto depends = args_spec_list[0]->Broaden(); + if (depends->isa()) { + depends->set_value(kAnyValue); + } return depends; } @@ -182,7 +185,7 @@ AbstractBasePtr InferImplControlDepend(const AnalysisEnginePtr &, const Primitiv auto src_size = arg_src->cast()->size(); auto dst_size = arg_src->cast()->size(); if (src_size > 1 && dst_size > 1) { - MS_LOG(EXCEPTION) << "Control depend can not setup operator dependcy relationship from tuple from tuple"; + MS_LOG(EXCEPTION) << "Control depend can not setup operator dependency relationship from tuple from tuple"; } } return std::make_shared(kAnyValue, kBool); @@ -505,7 +508,7 @@ AbstractBasePtr InferImplExpandDims(const AnalysisEnginePtr &, const PrimitivePt auto axis = primitive->GetAttr("axis"); auto value = GetValue(axis); if (value < -(SizeToInt(x_shape.size()) + 1) || value > SizeToInt(x_shape.size())) { - MS_LOG(EXCEPTION) << " axis value shoud be in range [-intput_x.dim-1,input_x.dim], but axis value is" << value + MS_LOG(EXCEPTION) << " axis value should be in range [-input_x.dim-1,input_x.dim], but axis value is" << value << " and input_x.dim is" << x_shape.size(); } if (value < 0) { diff --git a/tests/ut/cpp/abstract/utils_test.cc b/tests/ut/cpp/abstract/utils_test.cc index ea954c0641..ff44c1c040 100644 --- a/tests/ut/cpp/abstract/utils_test.cc +++ b/tests/ut/cpp/abstract/utils_test.cc @@ -32,26 +32,18 @@ TEST_F(TestUtils, test_join) { AbstractBasePtr abs_s1 = FromValue(static_cast(1), false); AbstractBasePtr abs_s2 = FromValue(static_cast(2), false); AbstractBasePtr abs_s_anything = FromValue(static_cast(2), true); + abs_s_anything->set_value(kAnyValue); AbstractBasePtr res_s1 = abs_s1->Join(abs_s2); ASSERT_EQ(*res_s1, *abs_s_anything); - // AbstractTuple join; - std::vector list1 = {1, 2, 3, 4, 5}; - std::vector list2 = {5, 4, 3, 2, 1}; - AbstractBasePtr abs_t1 = FromValue(list1, true); - AbstractBasePtr abs_t2 = FromValue(list2, true); - - AbstractBasePtr res_t1 = abs_t1->Join(abs_t2); - ASSERT_EQ(res_t1, abs_t1); - abs_s1 = FromValue(static_cast(1), false); AbstractBasePtr t1 = std::make_shared(AbstractBasePtrList({abs_s1, abs_s_anything})); AbstractBasePtr t2 = std::make_shared(AbstractBasePtrList({abs_s1, abs_s_anything})); AbstractBasePtr t3 = std::make_shared(AbstractBasePtrList({abs_s_anything, abs_s_anything})); - res_t1 = t1->Join(t2); + AbstractBasePtr res_t1 = t1->Join(t2); ASSERT_EQ(res_t1, t1); res_t1 = t1->Join(t3); diff --git a/tests/ut/cpp/optimizer/lib_test.cc b/tests/ut/cpp/optimizer/lib_test.cc index 3c794f97a8..38881362d5 100644 --- a/tests/ut/cpp/optimizer/lib_test.cc +++ b/tests/ut/cpp/optimizer/lib_test.cc @@ -111,8 +111,11 @@ TEST_F(TestOptLib, test_inline) { // add infer and renormalize std::shared_ptr res = std::make_shared(); AbstractBasePtrList args_spec_list; - AbstractBasePtr abstract_v1 = abstract::FromValue(static_cast(1), true); - AbstractBasePtr abstract_v2 = abstract::FromValue(static_cast(2), true); + tensor::TensorPtr x_tensor = std::make_shared(kFloat32->type_id(), std::vector{2, 3}); + tensor::TensorPtr y_tensor = std::make_shared(kFloat32->type_id(), std::vector{2, 3}); + + AbstractBasePtr abstract_v1 = abstract::FromValue(x_tensor, true); + AbstractBasePtr abstract_v2 = abstract::FromValue(y_tensor, true); args_spec_list.push_back(abstract_v1); args_spec_list.push_back(abstract_v2); AnalysisResult result = pipeline::AbstractAnalyze(res, before1, args_spec_list); diff --git a/tests/ut/cpp/pipeline/static_analysis/data_test.cc b/tests/ut/cpp/pipeline/static_analysis/data_test.cc index 5c333ed52f..248bf362bb 100644 --- a/tests/ut/cpp/pipeline/static_analysis/data_test.cc +++ b/tests/ut/cpp/pipeline/static_analysis/data_test.cc @@ -184,7 +184,7 @@ TEST_F(TestData, test_broaden) { AbstractBasePtr s2 = s1->Broaden(); ASSERT_TRUE(*s1->GetTypeTrack() == *s2->GetTypeTrack()); ASSERT_TRUE(*s1->GetValueTrack() == *MakeValue(int1)); - ASSERT_TRUE(s2->GetValueTrack()->isa()); + ASSERT_TRUE(s2->GetValueTrack()->isa()); AbstractFunctionPtr f1 = std::make_shared(std::make_shared(), AnalysisContext::DummyContext()); @@ -196,7 +196,7 @@ TEST_F(TestData, test_broaden) { AbstractList* l2_cast = dynamic_cast(l2.get()); ASSERT_TRUE(l2_cast != nullptr); AbstractBasePtr csr = AbstractJoin(l2_cast->elements()); - ASSERT_TRUE(csr->GetValueTrack()->isa()); + ASSERT_TRUE(csr->GetValueTrack()->isa()); } } // namespace abstract diff --git a/tests/ut/python/ops/test_control_ops.py b/tests/ut/python/ops/test_control_ops.py index 96e31459ab..a7d96a7772 100644 --- a/tests/ut/python/ops/test_control_ops.py +++ b/tests/ut/python/ops/test_control_ops.py @@ -761,27 +761,6 @@ def test_while_scalar(): out = net(x, y) -def test_while_tensor(): - class Net(nn.Cell): - def __init__(self): - super(Net, self).__init__() - self.t = Tensor(np.ones([6, 8, 10], np.int32)) - self.count = Tensor(np.array([10], np.int32)) - - def construct(self, x, y): - i = 0 - t = self.t - while (i < self.count): - t = t + x + y - i = i + 1 - return t - - net = Net() - x = Tensor(np.ones([6, 8, 10], np.int32)) - y = Tensor(np.ones([6, 8, 10], np.int32)) - out = net(x, y) - - def test_large_for_loop(): class Net(nn.Cell): def __init__(self): diff --git a/tests/ut/python/ops/test_tensor_slice.py b/tests/ut/python/ops/test_tensor_slice.py index de8190d0cc..5e1c02aa0e 100644 --- a/tests/ut/python/ops/test_tensor_slice.py +++ b/tests/ut/python/ops/test_tensor_slice.py @@ -20,6 +20,7 @@ from mindspore import Tensor, Parameter from mindspore import context from mindspore import dtype as mstype from mindspore.nn import Cell +from mindspore.ops import operations as P from ....mindspore_test_framework.mindspore_test import mindspore_test from ....mindspore_test_framework.pipeline.forward.compile_forward \ import pipeline_for_compile_forward_ge_graph_for_case_by_case_config, \ @@ -683,6 +684,27 @@ def test_tensor_assign_bool_index(): net4(Ta, Tensor(u_scalar, mstype.int32)) +def test_trivial_call_function_twice_with_diff_key_value_para(): + class Net(Cell): + def __init__(self): + super(Net, self).__init__() + self.arange = Tensor(np.array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10])) + self.concat = P.Concat(axis=0) + + def compute(self, x, is_decoder): + if is_decoder: + return self.arange[:x] + return self.arange[1:x + 1] + + def construct(self): + result1 = self.compute(7, is_decoder=True) + result2 = self.compute(6, is_decoder=False) + return self.concat((result1, result2)) + + net = Net() + net() + + test_cases = [ ('TensorAssignWithTupleEllipsis2', { 'block': TensorAssignWithTupleEllipsis2(),