From 9708e58259f18db3b6932d64bff6d5dcb2c91db7 Mon Sep 17 00:00:00 2001 From: kingfo Date: Thu, 11 Jun 2020 17:06:01 +0800 Subject: [PATCH] fix TupleToArray & Cast operator issue --- mindspore/common/parameter.py | 19 ++---- mindspore/nn/cell.py | 8 ++- mindspore/ops/operations/array_ops.py | 9 ++- tests/st/pynative/test_ops.py | 31 +++++++++ tests/ut/python/nn/optim/test_optimizer.py | 4 +- .../nn/test_parameter_operation.py | 67 +++++++++++++++++++ .../python/pynative_mode/test_parse_method.py | 8 +++ 7 files changed, 126 insertions(+), 20 deletions(-) create mode 100644 tests/st/pynative/test_ops.py create mode 100644 tests/ut/python/pynative_mode/nn/test_parameter_operation.py diff --git a/mindspore/common/parameter.py b/mindspore/common/parameter.py index a943cbcf63..30f4bd7e48 100644 --- a/mindspore/common/parameter.py +++ b/mindspore/common/parameter.py @@ -15,7 +15,7 @@ """Parameter for cell.""" import numbers -from copy import copy, deepcopy +from copy import copy from mindspore import context from . import dtype as mstype from .initializer import initializer, Initializer @@ -191,25 +191,16 @@ class Parameter: return self.default_input def __add__(self, other): - res = deepcopy(self) - res.default_input = res.default_input + other - return res + return self.default_input + other def __sub__(self, other): - res = deepcopy(self) - res.default_input = res.default_input - other - return res + return self.default_input - other def __mul__(self, other): - res = deepcopy(self) - default_input = res.default_input * other - res.default_input = Tensor(default_input.asnumpy().copy()) - return res + return self.default_input * other def __truediv__(self, other): - res = deepcopy(self) - res.default_input = res.default_input / other - return res + return self.default_input / other def __setitem__(self, index, value): return self diff --git a/mindspore/nn/cell.py b/mindspore/nn/cell.py index 4024f8315f..0377e9dcc3 100755 --- a/mindspore/nn/cell.py +++ b/mindspore/nn/cell.py @@ -202,6 +202,7 @@ class Cell: if context.get_context("mode") == context.GRAPH_MODE: out = self.compile_and_run(*inputs) return out + self.init_parameters_data() orign_grad = [] if self.requires_grad is True: _pynative_exec.set_grad_flag(True) @@ -254,9 +255,12 @@ class Cell: value.update_parameters_name(name + '.') cells[name] = value elif params and name in params: - if value is not None: + if isinstance(value, Tensor) and self._params[name] is not None: + self._params[name].set_parameter_data(value) + elif value is not None: raise TypeError("Expected type in (Parameter, ParameterTuple), but got {}.".format(type(value))) - self.insert_param_to_cell(name, None) + else: + self.insert_param_to_cell(name, None) elif cells and name in cells: if value is not None: raise TypeError("Expected type is cell, but got {}.".format(type(value))) diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py index bd395d94f5..5f11625563 100644 --- a/mindspore/ops/operations/array_ops.py +++ b/mindspore/ops/operations/array_ops.py @@ -30,7 +30,7 @@ from ...common import dtype as mstype from ...common.tensor import Tensor from ..operations.math_ops import _infer_shape_reduce from .._utils import get_concat_offset -from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register +from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register, _run_op from ..._c_expression import signature_rw as sig_rw from ..._c_expression import signature_kind as sig_kind from ..._c_expression import signature_dtype as sig_dtype @@ -990,9 +990,14 @@ class TupleToArray(PrimitiveWithInfer): ret = np.array(x, np.int32) else: ret = np.array(x, np.float32) - return Tensor(ret) + def __call__(self, x): + args = list() + if isinstance(x, range): + args.append(tuple(x)) + return _run_op(self, self.name, args) + class ScalarToArray(PrimitiveWithInfer): """ diff --git a/tests/st/pynative/test_ops.py b/tests/st/pynative/test_ops.py new file mode 100644 index 0000000000..3cec24fb10 --- /dev/null +++ b/tests/st/pynative/test_ops.py @@ -0,0 +1,31 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +import numpy as np + +import mindspore as ms +import mindspore.ops.operations as P +from mindspore import context, Tensor + + +def test_cast(): + """ tests cast for same dtype""" + context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend") + input_np = np.random.randn(2, 3, 4, 5).astype(np.float32) + input_x = Tensor(input_np) + type_dst = ms.float32 + cast = P.Cast() + result = cast(input_x, type_dst) + assert result.dtype() == type_dst diff --git a/tests/ut/python/nn/optim/test_optimizer.py b/tests/ut/python/nn/optim/test_optimizer.py index 0a12f11365..70b79e97d7 100644 --- a/tests/ut/python/nn/optim/test_optimizer.py +++ b/tests/ut/python/nn/optim/test_optimizer.py @@ -52,11 +52,11 @@ class TestAdam(): use_nesterov=False, weight_decay=0.0, loss_scale=1.0) def test_construct(self): - with pytest.raises(TypeError): + with pytest.raises(RuntimeError): gradient = Tensor(np.zeros([1, 2, 3])) adam = Adam(params, learning_rate=1e-3, beta1=0.9, beta2=0.999, eps=1e-8, use_locking=False, use_nesterov=False, weight_decay=0.0, loss_scale=1.0) - adam.construct(gradient) + adam(gradient) class TestSGD(): diff --git a/tests/ut/python/pynative_mode/nn/test_parameter_operation.py b/tests/ut/python/pynative_mode/nn/test_parameter_operation.py new file mode 100644 index 0000000000..7f609361f0 --- /dev/null +++ b/tests/ut/python/pynative_mode/nn/test_parameter_operation.py @@ -0,0 +1,67 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test_tensor_operation """ +import numpy as np + +import mindspore.nn as nn +from mindspore import Tensor, Parameter +from mindspore import context + + +def setup_module(module): + context.set_context(mode=context.PYNATIVE_MODE) + +def test_parameter_add(): + x = Parameter(Tensor(np.ones((3, 3)).astype(np.float32)), name="ref") + y = Tensor(np.ones((3, 3)).astype(np.float32)) + expect = np.ones((3, 3)).astype(np.float32) * 2 + z = x + y + assert np.allclose(z.asnumpy(), expect) + +def test_parameter_sub(): + x = Parameter(Tensor(np.ones((3, 3)).astype(np.float32) * 2), name="ref") + y = Tensor(np.ones((3, 3)).astype(np.float32)) + expect = np.ones((3, 3)).astype(np.float32) + z = x - y + assert np.allclose(z.asnumpy(), expect) + +def test_parameter_mul(): + x = Parameter(Tensor(np.ones((3, 3)).astype(np.float32) * 2), name="ref") + y = Tensor(np.ones((3, 3)).astype(np.float32) * 2) + expect = np.ones((3, 3)).astype(np.float32) * 4 + z = x * y + assert np.allclose(z.asnumpy(), expect) + +def test_parameter_div(): + x = Parameter(Tensor(np.ones((3, 3)).astype(np.float32) * 8), name="ref") + y = Tensor(np.ones((3, 3)).astype(np.float32) * 2) + expect = np.ones((3, 3)).astype(np.float32) * 4 + z = x / y + assert np.allclose(z.asnumpy(), expect) + +class ParameterNet(nn.Cell): + def __init__(self): + super(ParameterNet, self).__init__() + self.weight = Parameter(Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], np.float32)), name="ref") + + def construct(self, x): + self.weight = x + +def test_parameter_assign(): + """test parameter assign with tensor""" + input_x = Tensor(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 8.0]], np.float32)) + net = ParameterNet() + net(input_x) + assert np.allclose(net.weight.data.asnumpy(), input_x.asnumpy()) diff --git a/tests/ut/python/pynative_mode/test_parse_method.py b/tests/ut/python/pynative_mode/test_parse_method.py index 8a648ac90e..f189b825e9 100644 --- a/tests/ut/python/pynative_mode/test_parse_method.py +++ b/tests/ut/python/pynative_mode/test_parse_method.py @@ -31,6 +31,7 @@ from mindspore.common.api import ms_function from mindspore.common.tensor import Tensor from mindspore.ops.composite import core from mindspore.ops.primitive import constexpr +from mindspore.ops import functional as F from ..ut_filter import non_graph_engine @@ -427,3 +428,10 @@ def test_expr(): def tuple_len(x): assert len(x) == 2 tuple_len(a) + + +def test_tuple_to_array(): + """ test range tuple to array """ + range_x = range(10) + res = F.tuple_to_array(range_x) + print(res)