From 94c9019d8e57a25b2c1bf0ea1e7c23f3e27c0b75 Mon Sep 17 00:00:00 2001 From: buxue Date: Tue, 2 Jun 2020 20:09:35 +0800 Subject: [PATCH] restricting modify non_Parameter class members --- mindspore/ccsrc/pipeline/parse/parse.cc | 20 ++++++++--- tests/ut/python/dtype/test_dictionary.py | 9 +++-- tests/ut/python/dtype/test_list.py | 44 ++++++++++++++++++++++-- 3 files changed, 64 insertions(+), 9 deletions(-) diff --git a/mindspore/ccsrc/pipeline/parse/parse.cc b/mindspore/ccsrc/pipeline/parse/parse.cc index c6e5d3713a..24785fe482 100644 --- a/mindspore/ccsrc/pipeline/parse/parse.cc +++ b/mindspore/ccsrc/pipeline/parse/parse.cc @@ -1175,11 +1175,11 @@ void Parser::HandleAssignClassMember(const FunctionBlockPtr &block, const py::ob auto filename = location[0].cast(); auto line_no = location[1].cast(); // Now only support the self.xxx = yyy, where self.xxx must be a defined Parameter type - if (!py::hasattr(ast()->obj(), attr_name.c_str())) { + if (!py::hasattr(ast()->obj(), common::SafeCStr(attr_name))) { MS_EXCEPTION(TypeError) << "'" << var_name << "' should be a Parameter, but not defined, at " << filename << ":" << line_no; } - auto obj = ast()->obj().attr(attr_name.c_str()); + auto obj = ast()->obj().attr(common::SafeCStr(attr_name)); auto obj_type = obj.attr("__class__").attr("__name__"); if (!py::hasattr(obj, "__parameter__")) { MS_EXCEPTION(TypeError) << "'" << var_name << "' should be a Parameter, but got '" @@ -1205,8 +1205,18 @@ void Parser::HandleAssignSubscript(const FunctionBlockPtr &block, const py::obje // getitem apply should return the sequence data structure itself std::string var_name = ""; if (ast_->IsClassMember(value_obj)) { - var_name = "self."; - (void)var_name.append(value_obj.attr("attr").cast()); + std::string attr_name = value_obj.attr("attr").cast(); + var_name = "self." + attr_name; + if (!py::hasattr(ast()->obj(), common::SafeCStr(attr_name))) { + MS_EXCEPTION(TypeError) << "'" << var_name << "' was not defined in the class '__init__' function."; + } + auto obj = ast()->obj().attr(common::SafeCStr(attr_name)); + auto obj_type = obj.attr("__class__").attr("__name__"); + if (!py::hasattr(obj, "__parameter__")) { + MS_EXCEPTION(TypeError) << "'" << var_name << "' should be a Parameter, but got '" + << py::str(obj).cast() << "' with type '" + << py::str(obj_type).cast() << "'."; + } } else { var_name = value_obj.attr("id").cast(); } @@ -1231,7 +1241,7 @@ void Parser::WriteAssignVars(const FunctionBlockPtr &block, const py::object &ta } } -// process a assign statement , such as a =b, a,b = tup +// process a assign statement, such as a =b, a,b = tup FunctionBlockPtr Parser::ParseAssign(const FunctionBlockPtr &block, const py::object &node) { MS_LOG(DEBUG) << "Process ast assgin"; py::object value_object = python_adapter::GetPyObjAttr(node, "value"); diff --git a/tests/ut/python/dtype/test_dictionary.py b/tests/ut/python/dtype/test_dictionary.py index 14033873f7..052372dd39 100644 --- a/tests/ut/python/dtype/test_dictionary.py +++ b/tests/ut/python/dtype/test_dictionary.py @@ -17,6 +17,7 @@ @Desc : test_dictionary """ import numpy as np +import pytest from mindspore import Tensor, context from mindspore.nn import Cell @@ -89,7 +90,9 @@ def test_dict_set_or_get_item(): return ret net = DictNet() - assert net() == (88, 99, 4, 5, 6) + with pytest.raises(TypeError) as ex: + net() + assert "'self.dict_' should be a Parameter" in str(ex.value) def test_dict_set_or_get_item_2(): @@ -135,7 +138,9 @@ def test_dict_set_or_get_item_3(): return self.dict_["x"] net = DictNet() - assert net() == Tensor(np.ones([4, 2, 3], np.float32)) + with pytest.raises(TypeError) as ex: + net() + assert "'self.dict_' should be a Parameter" in str(ex.value) def test_dict_set_item(): diff --git a/tests/ut/python/dtype/test_list.py b/tests/ut/python/dtype/test_list.py index f9ddb1a16c..971bf85bef 100644 --- a/tests/ut/python/dtype/test_list.py +++ b/tests/ut/python/dtype/test_list.py @@ -15,6 +15,7 @@ import functools import numpy as np +import pytest import mindspore.nn as nn import mindspore.context as context from mindspore import Tensor @@ -24,6 +25,7 @@ from tests.mindspore_test_framework.mindspore_test import mindspore_test from tests.mindspore_test_framework.pipeline.forward.compile_forward \ import pipeline_for_compile_forward_ge_graph_for_case_by_case_config +context.set_context(mode=context.GRAPH_MODE) def test_list_equal(): class Net(nn.Cell): @@ -109,7 +111,7 @@ def test_list_append(): assert net(x, y) == y -def test_list_append_2(): +def test_class_member_list_append(): class Net(nn.Cell): def __init__(self, z: list): super(Net, self).__init__() @@ -129,7 +131,45 @@ def test_list_append_2(): y = Tensor(np.zeros([3, 4, 5], np.int32)) z = [[1, 2], 3] net = Net(z) - assert net(x, y) == x + with pytest.raises(TypeError) as ex: + net(x, y) + assert "'self.z' should be a Parameter, but got '[[1, 2], 3]' with type 'list'." in str(ex.value) + + +def test_class_member_not_defined(): + class Net(nn.Cell): + def __init__(self, z: list): + super(Net, self).__init__() + self.z = z + + def construct(self, x, y): + self.x[0] = 9 + return self.x + + z = [[1, 2], 3] + net = Net(z) + with pytest.raises(TypeError) as ex: + net() + assert "'self.x' was not defined in the class '__init__' function." in str(ex.value) + + +def test_change_list_element(): + class Net(nn.Cell): + def __init__(self, z: list): + super(Net, self).__init__() + self.z = z + + def construct(self, x, y): + self.z[0] = x + return self.z[0] + + x = Tensor(np.ones([6, 8, 10], np.int32)) + y = Tensor(np.zeros([3, 4, 5], np.int32)) + z = [[1, 2], 3] + net = Net(z) + with pytest.raises(TypeError) as ex: + net(x, y) + assert "'self.z' should be a Parameter, but got '[[1, 2], 3]' with type 'list'." in str(ex.value) class ListOperate(nn.Cell):