From 157710ca0fb5573261f8cb0ee81d733d1a8a3737 Mon Sep 17 00:00:00 2001
From: Wei Luning
Date: Wed, 29 Apr 2020 10:55:06 +0800
Subject: [PATCH] bugfix:

* fix conversion of network outputs that are tuples of tuples
* raise a TypeError when a kRWWrite input is not a Parameter
* input x of ScatterNdUpdate should be a Parameter node

---
 example/resnet101_imagenet2012/train.py        |  1 -
 .../ccsrc/operator/composite/do_signature.cc   |  2 ++
 mindspore/common/api.py                        |  9 ++++----
 mindspore/nn/optim/adam.py                     | 17 +++++++-------
 mindspore/nn/optim/lamb.py                     | 22 +++++++++---------
 mindspore/ops/operations/array_ops.py          |  7 ++++++
 tests/st/ops/gpu/test_assign_add_op.py         | 23 +++++++++++--------
 tests/st/ops/gpu/test_assign_op.py             | 15 ++++++------
 tests/ut/python/ops/test_math_ops.py           | 12 ++++++++++
 tests/ut/python/ops/test_momentum.py           |  3 +--
 tests/ut/python/ops/test_nn_ops.py             | 12 +++++-----
 tests/ut/python/ops/test_ops.py                | 20 ++++++----------
 12 files changed, 80 insertions(+), 63 deletions(-)

diff --git a/example/resnet101_imagenet2012/train.py b/example/resnet101_imagenet2012/train.py
index 3d0a23f93a..6a89a212ca 100755
--- a/example/resnet101_imagenet2012/train.py
+++ b/example/resnet101_imagenet2012/train.py
@@ -14,7 +14,6 @@
 # ============================================================================
 """train_imagenet."""
 import os
-import math
 import argparse
 import random
 import numpy as np
diff --git a/mindspore/ccsrc/operator/composite/do_signature.cc b/mindspore/ccsrc/operator/composite/do_signature.cc
index c3fe45a48a..1098ed1520 100644
--- a/mindspore/ccsrc/operator/composite/do_signature.cc
+++ b/mindspore/ccsrc/operator/composite/do_signature.cc
@@ -195,6 +195,8 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
       param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefKey), param});
     }
     // If sig is SignatureEnumRW::kRWRef, not do anything.
+  } else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) {
+    MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter.";
   }
   // add cast op here
   if (assign_source != nullptr && sig != SignatureEnumRW::kRWWrite) {
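Note on the do_signature.cc change: BuildNewCNode now rejects, at graph
compile time, any argument bound to a SignatureEnumRW::kRWWrite input whose
type is not a RefKey, i.e. anything not backed by a Parameter. A minimal
sketch of the behavior this enforces, modeled on the GPU tests further
down (the net and values here are illustrative, not part of the patch):

    import numpy as np
    import mindspore.nn as nn
    import mindspore.context as context
    from mindspore import Tensor, Parameter
    from mindspore.ops import operations as P

    class AssignAddNet(nn.Cell):
        def __init__(self, value):
            super(AssignAddNet, self).__init__()
            # AssignAdd's first input is writable (kRWWrite), so it must be
            # a Parameter; passing a plain Tensor now fails compilation with
            # TypeError: "... input 0 should be a Parameter."
            self.var = Parameter(value, name="var")
            self.add = P.AssignAdd()

        def construct(self, y):
            return self.add(self.var, y)

    context.set_context(mode=context.GRAPH_MODE)
    net = AssignAddNet(Tensor(np.ones([3]).astype(np.float32)))
    out = net(Tensor(np.ones([3]).astype(np.float32)))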
diff --git a/mindspore/common/api.py b/mindspore/common/api.py
index 455e7a7f4f..3710e40996 100644
--- a/mindspore/common/api.py
+++ b/mindspore/common/api.py
@@ -70,12 +70,11 @@ def _wrap_func(fn):
         def _convert_data(data):
             if isinstance(data, Tensor) and not isinstance(data, MsTensor):
                 return MsTensor(data)
+            if isinstance(data, tuple):
+                return tuple(_convert_data(x) for x in data)
+            if isinstance(data, list):
+                return list(_convert_data(x) for x in data)
             return data
-
-        if isinstance(results, tuple):
-            return tuple(_convert_data(x) for x in results)
-        if isinstance(results, list):
-            return list(_convert_data(x) for x in results)
         return _convert_data(results)
     return wrapper
diff --git a/mindspore/nn/optim/adam.py b/mindspore/nn/optim/adam.py
index 87c46380f6..1a386556d9 100755
--- a/mindspore/nn/optim/adam.py
+++ b/mindspore/nn/optim/adam.py
@@ -57,21 +57,22 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad
     op_reshape = P.Reshape()
     op_shape = P.Shape()

-    param = op_cast(param, mstype.float32)
-    m = op_cast(m, mstype.float32)
-    v = op_cast(v, mstype.float32)
-    gradient = op_cast(gradient, mstype.float32)
+    param_fp32 = op_cast(param, mstype.float32)
+    m_fp32 = op_cast(m, mstype.float32)
+    v_fp32 = op_cast(v, mstype.float32)
+    gradient_fp32 = op_cast(gradient, mstype.float32)

-    next_m = op_mul(beta1, m) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient)
+    next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32)

-    next_v = op_mul(beta2, v) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta2, op_square(gradient))
+    next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
+                                            - beta2, op_square(gradient_fp32))

     update = next_m / (op_sqrt(next_v) + eps)
     if decay_flag:
-        update = update + op_mul(weight_decay_tensor, param)
+        update = update + op_mul(weight_decay_tensor, param_fp32)

     update_with_lr = op_mul(lr, update)
-    next_param = param - op_reshape(update_with_lr, op_shape(param))
+    next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))

     next_v = F.depend(next_v, F.assign(param, next_param))
     next_v = F.depend(next_v, F.assign(m, next_m))
diff --git a/mindspore/nn/optim/lamb.py b/mindspore/nn/optim/lamb.py
index cbeb6fa674..e026b1c560 100755
--- a/mindspore/nn/optim/lamb.py
+++ b/mindspore/nn/optim/lamb.py
@@ -67,23 +67,23 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para
     op_fill = P.Fill()
     op_dtype = P.DType()

-    param = op_cast(param, mstype.float32)
-    m = op_cast(m, mstype.float32)
-    v = op_cast(v, mstype.float32)
-    gradient = op_cast(gradient, mstype.float32)
+    param_fp32 = op_cast(param, mstype.float32)
+    m_fp32 = op_cast(m, mstype.float32)
+    v_fp32 = op_cast(v, mstype.float32)
+    gradient_fp32 = op_cast(gradient, mstype.float32)

-    next_m = op_mul(beta1, m) + op_mul(op_cast(num_one, mstype.float32) - beta1, gradient)
+    next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta1, gradient_fp32)

-    next_v = op_mul(beta2, v) + op_mul(op_cast(num_one, mstype.float32) - beta2, op_square(gradient))
+    next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta2,
+                                            op_square(gradient_fp32))

     next_mm = next_m / (op_cast(num_one, mstype.float32)
                         - op_pow(beta1, op_cast(global_step + num_one, mstype.float32)))
     next_vv = next_v / (op_cast(num_one, mstype.float32)
                         - op_pow(beta2, op_cast(global_step + num_one, mstype.float32)))
-    w_norm = op_norm(param)
-    g_norm = op_norm(gradient)
+    w_norm = op_norm(param_fp32)
+    g_norm = op_norm(gradient_fp32)

-    g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt(next_vv + eps)) + weight_decay_tensor * param)
+    g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt(next_vv + eps)) + weight_decay_tensor * param_fp32)
     zeros = F.zeros_like_tensor(w_norm)
     ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0)
     trust_ratio = op_select(
@@ -95,11 +95,11 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para
     update = next_mm / (op_sqrt(next_vv) + eps)
     if decay_flag:
-        update = update + op_mul(weight_decay_tensor, param)
+        update = update + op_mul(weight_decay_tensor, param_fp32)

     update_with_lr = op_mul(op_mul(trust_ratio, lr), update)
-    next_param = param - op_reshape(update_with_lr, op_shape(param))
+    next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))

     next_v = F.depend(next_v, F.assign(param, next_param))
     next_v = F.depend(next_v, F.assign(m, next_m))
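Two things are going on above. In api.py, _convert_data is now recursive, so
graph outputs that nest tuples or lists (the "tuple of tuple" case from the
subject) are converted leaf by leaf; previously the tuple/list walk was
applied once to `results` outside _convert_data, so the elements of an inner
tuple came back unconverted. In adam.py and lamb.py, the cast results are
bound to new *_fp32 names so that `param` and `m` keep referring to the
original Parameters, which the trailing F.assign(param, ...) and
F.assign(m, ...) calls require now that kRWWrite inputs must be Parameters.
A self-contained toy of the recursive conversion (Raw/Wrapped are stand-ins
for the backend tensor and MsTensor, purely for illustration):

    class Raw:                       # backend tensor stand-in
        def __init__(self, v):
            self.v = v

    class Wrapped:                   # MsTensor-style wrapper stand-in
        def __init__(self, raw):
            self.v = raw.v

    def convert(data):
        """Convert every Raw leaf while preserving tuple/list structure."""
        if isinstance(data, Raw):
            return Wrapped(data)
        if isinstance(data, tuple):
            return tuple(convert(x) for x in data)
        if isinstance(data, list):
            return list(convert(x) for x in data)
        return data

    out = convert(((Raw(1), Raw(2)), [Raw(3)]))
    assert isinstance(out[0][0], Wrapped)    # inner tuple converted too
    assert isinstance(out[1][0], Wrapped)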
diff --git a/mindspore/ops/operations/array_ops.py b/mindspore/ops/operations/array_ops.py
index b69b083e03..f611bc9617 100644
--- a/mindspore/ops/operations/array_ops.py
+++ b/mindspore/ops/operations/array_ops.py
@@ -24,6 +24,8 @@
 import itertools
 import numbers
 import numpy as np
+from ..._c_expression import signature_rw as sig_rw
+from ..._c_expression import signature_kind as sig_kind
 from ..._checkparam import Validator as validator
 from ..._checkparam import Rel
 from ...common import dtype as mstype
@@ -1965,6 +1967,11 @@ class ScatterNdUpdate(PrimitiveWithInfer):
         >>> op = P.ScatterNdUpdate()
         >>> output = op(input_x, indices, update)
     """
+    __mindspore_signature__ = (
+        ('input_x', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD),
+        ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD),
+        ('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD)
+    )

     @prim_attr_register
     def __init__(self, use_locking=True):
diff --git a/tests/st/ops/gpu/test_assign_add_op.py b/tests/st/ops/gpu/test_assign_add_op.py
index b021a32f32..4c95177fb6 100644
--- a/tests/st/ops/gpu/test_assign_add_op.py
+++ b/tests/st/ops/gpu/test_assign_add_op.py
@@ -14,19 +14,20 @@
 # ============================================================================

 import pytest
-from mindspore import Tensor
+from mindspore import Tensor, Parameter
 from mindspore.ops import operations as P
 import mindspore.nn as nn
 import numpy as np
 import mindspore.context as context

 class AssignAdd(nn.Cell):
-    def __init__( self):
+    def __init__(self, value):
         super(AssignAdd, self).__init__()
+        self.var = Parameter(value, name="var")
         self.add = P.AssignAdd()

-    def construct(self, x, y):
-        res = self.add(x, y)
+    def construct(self, y):
+        res = self.add(self.var, y)
         return res

 @pytest.mark.level0
@@ -58,15 +59,17 @@ def test_assign_add():
     y2 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32))

     context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
-    add = AssignAdd()
-    output1 = add(x1, y1)
+    add = AssignAdd(x1)
+    output1 = add(y1)
     assert (output1.asnumpy() == expect1).all()
-    output2 = add(output1, y1)
+    add = AssignAdd(output1)
+    output2 = add(y1)
     assert (output2.asnumpy() == expect2).all()

     context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
-    add = AssignAdd()
-    output1 = add(x2, y2)
+    add = AssignAdd(x2)
+    output1 = add(y2)
     assert (output1.asnumpy() == expect1).all()
-    output2 = add(output1, y2)
+    add = AssignAdd(output1)
+    output2 = add(y2)
     assert (output2.asnumpy() == expect2).all()
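With the new __mindspore_signature__, ScatterNdUpdate declares input_x as
RW_WRITE, so the do_signature.cc check above applies to it. A minimal usage
sketch (shapes match the relocated test case below; the wrapper Cell is
illustrative, not part of the patch):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, Parameter
    from mindspore.ops import operations as P

    class ScatterNdUpdateNet(nn.Cell):
        def __init__(self):
            super(ScatterNdUpdateNet, self).__init__()
            # input_x is RW_WRITE: it must live in a Parameter node.
            self.input_x = Parameter(Tensor(np.ones((2, 3), np.float32)),
                                     name="input_x")
            self.op = P.ScatterNdUpdate()

        def construct(self, indices, value):
            return self.op(self.input_x, indices, value)

    net = ScatterNdUpdateNet()
    out = net(Tensor(np.ones((2, 2), np.int32)),
              Tensor(np.ones((2,), np.float32)))
    # Passing a plain Tensor as input_x would now raise a TypeError, which
    # is what the ScatterNdUpdate entry in raise_set below verifies.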
diff --git a/tests/st/ops/gpu/test_assign_op.py b/tests/st/ops/gpu/test_assign_op.py
index 4cf730d763..f1fb908268 100644
--- a/tests/st/ops/gpu/test_assign_op.py
+++ b/tests/st/ops/gpu/test_assign_op.py
@@ -14,7 +14,7 @@
 # ============================================================================

 import pytest
-from mindspore import Tensor
+from mindspore import Tensor, Parameter
 from mindspore.ops import operations as P
 import mindspore.nn as nn
 import numpy as np
@@ -22,12 +22,13 @@ import mindspore.context as context

 class Net(nn.Cell):
-    def __init__(self):
+    def __init__(self, value):
         super(Net, self).__init__()
+        self.var = Parameter(value, name="var")
         self.assign = P.Assign()

-    def construct(self, var, value):
-        return self.assign(var, value)
+    def construct(self, value):
+        return self.assign(self.var, value)

 x = np.array([[1.2, 1], [1, 0]]).astype(np.float32)
 value = np.array([[1, 2], [3, 4.0]]).astype(np.float32)
@@ -37,13 +38,13 @@ value = np.array([[1, 2], [3, 4.0]]).astype(np.float32)
 @pytest.mark.env_onecard
 def test_assign():
     context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
-    assign = Net()
     var = Tensor(x)
-    output = assign(var, Tensor(value))
+    assign = Net(var)
+    output = assign(Tensor(value))
     error = np.ones(shape=[2, 2]) * 1.0e-6
     diff1 = output.asnumpy() - value
-    diff2 = var.asnumpy() - value
+    diff2 = assign.var.default_input.asnumpy() - value
     assert np.all(diff1 < error)
     assert np.all(-diff1 < error)
     assert np.all(diff2 < error)
diff --git a/tests/ut/python/ops/test_math_ops.py b/tests/ut/python/ops/test_math_ops.py
index a4a645a7ef..7ada847aac 100755
--- a/tests/ut/python/ops/test_math_ops.py
+++ b/tests/ut/python/ops/test_math_ops.py
@@ -341,6 +341,15 @@ class SignNet(nn.Cell):
     def construct(self, x):
         return self.sign(x)

+class AssignAdd(nn.Cell):
+    def __init__(self):
+        super().__init__()
+        self.op = P.AssignAdd()
+        self.inputdata = Parameter(initializer(1, [1], ms.float32), name="global_step")
+
+    def construct(self, input_):
+        self.inputdata = input_
+        return self.op(self.inputdata, input_)

 test_case_math_ops = [
     ('MatMulGrad', {
@@ -413,6 +422,9 @@ raise_set = [
     ('StridedSlice_4_Error', {
         'block': (lambda x: P.StridedSlice(new_axis_mask="1.1"), {'exception': TypeError}),
         'desc_inputs': [0]}),
+    ('AssignAdd_Error', {
+        'block': (P.AssignAdd(), {'exception': TypeError}),
+        'desc_inputs': [[1]]}),
 ]
diff --git a/tests/ut/python/ops/test_momentum.py b/tests/ut/python/ops/test_momentum.py
index 3334f1670a..28b9637015 100644
--- a/tests/ut/python/ops/test_momentum.py
+++ b/tests/ut/python/ops/test_momentum.py
@@ -38,8 +38,7 @@ def tensor_run_opt(opt, iters, learning_rate, momentum,
                    gradient, variable, moment):
     """ tensor_run_opt """
     success = True
-    new_weight = opt(gradient, moment, variable,
-                     learning_rate, momentum)
+    new_weight = opt(variable, moment, learning_rate, gradient, momentum)
     success = F.depend(success, F.assign(variable, new_weight))
     return success
diff --git a/tests/ut/python/ops/test_nn_ops.py b/tests/ut/python/ops/test_nn_ops.py
index ab6f31095d..5038ee28a0 100644
--- a/tests/ut/python/ops/test_nn_ops.py
+++ b/tests/ut/python/ops/test_nn_ops.py
@@ -446,12 +446,6 @@ test_cases = [
         'desc_inputs': [[128, 32, 32, 64]],
         'desc_bprop': [[128, 32, 32, 64]],
     }),
-    ('ApplyMomentum', {
-        'block': P.ApplyMomentum(),
-        'desc_inputs': [[2], [128, 32, 32, 64], [128, 32, 32, 64], [128, 32, 32, 64], [128, 32, 32, 64]],
-        'desc_bprop': [[128, 32, 32, 64]],
-        'skip': ['backward']
-    }),
     ('ScalarSummary', {
         'block': ScalarSummaryNet(),
         'desc_inputs': [2.2],
@@ -515,6 +509,12 @@ test_cases = [
 ]

 test_cases_for_verify_exception = [
+    ('ApplyMomentum_Error', {
+        'block': (P.ApplyMomentum(), {'exception': TypeError}),
+        'desc_inputs': [[2], [128, 32, 32, 64], [128, 32, 32, 64], [128, 32, 32, 64], [128, 32, 32, 64]],
+        'desc_bprop': [[128, 32, 32, 64]],
+        'skip': ['backward']
+    }),
     ('Conv2d_ValueError_1', {
         'block': (lambda _: P.Conv2D(3, 4, mode=-2.0), {'exception': TypeError}),
         'desc_inputs': [0],
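The test_momentum.py change fixes the argument order to match
ApplyMomentum's call signature, writable inputs first; under the new
kRWWrite check the old order (gradient first) would fail to compile, since
gradient is not a Parameter. A minimal sketch of correct usage (the wrapper
Cell, shapes, and names are illustrative, not part of the patch):

    import numpy as np
    import mindspore.nn as nn
    from mindspore import Tensor, Parameter
    from mindspore.ops import operations as P

    class MomentumNet(nn.Cell):
        def __init__(self):
            super(MomentumNet, self).__init__()
            # variable and accumulation are the writable inputs; both must
            # be Parameters under the new check.
            self.variable = Parameter(Tensor(np.zeros([2]).astype(np.float32)),
                                      name="variable")
            self.accumulation = Parameter(Tensor(np.zeros([2]).astype(np.float32)),
                                          name="accumulation")
            self.opt = P.ApplyMomentum()

        def construct(self, lr, gradient, momentum):
            # Argument order: variable, accumulation, learning_rate,
            # gradient, momentum.
            return self.opt(self.variable, self.accumulation, lr, gradient, momentum)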
diff --git a/tests/ut/python/ops/test_ops.py b/tests/ut/python/ops/test_ops.py
index 68ff816fb3..23bf7da4b4 100755
--- a/tests/ut/python/ops/test_ops.py
+++ b/tests/ut/python/ops/test_ops.py
@@ -674,12 +674,6 @@ test_case_nn_ops = [
         'desc_inputs': [[128, 64, 32, 32], [128, 64, 32, 32], [64], [64], [64], [64]],
         'desc_bprop': [[128, 64, 32, 32], [64], [64], [64], [64]],
         'skip': ['backward']}),
-    ('ApplyMomentum', {
-        'block': P.ApplyMomentum(),
-        'desc_inputs': [[128, 32, 32, 64], [128, 32, 32, 64],
-                        [32, 32, 64], [32, 32, 64], [32, 32, 64]],
-        'desc_bprop': [[128, 32, 32, 64]],
-        'skip': ['backward']}),
     ('TopK', {
         'block': P.TopK(),
         'desc_const': [5],
@@ -1113,12 +1107,6 @@ test_case_other_ops = [
         'desc_inputs': (Tensor(np.ones((1, 3, 6, 6), np.float32)),
                         Tensor(np.ones((2, 4), np.int32))),
         'desc_bprop': [[2]]}),
-    ('ScatterNdUpdate', {
-        'block': P.ScatterNdUpdate(),
-        'desc_inputs': (Tensor(np.ones((2, 3), np.float32)),
-                        Tensor(np.ones((2, 2), np.int32)),
-                        Tensor(np.ones((2,), np.float32))),
-        'desc_bprop': [[2, 3]]}),
     ('ScatterNd', {
         'block': P.ScatterNd(),
         'desc_const': [(3, 3)],
@@ -1178,7 +1166,7 @@ import mindspore.context as context

 @non_graph_engine
 @mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
 def test_exec():
-    context.set_context(mode=context.GRAPH_MODE)
+    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
     return test_exec_case

@@ -1207,6 +1195,12 @@ raise_set = [
         'block': (NetForFlatten0D(), {'exception': ValueError}),
         'desc_inputs': [Tensor(np.array(0).astype(np.int32))],
         'desc_bprop': [Tensor(np.array(0).astype(np.int32))]}),
+    ('ScatterNdUpdate', {
+        'block': (P.ScatterNdUpdate(), {'exception': TypeError}),
+        'desc_inputs': (Tensor(np.ones((2, 3), np.float32)),
+                        Tensor(np.ones((2, 2), np.int32)),
+                        Tensor(np.ones((2,), np.float32))),
+        'desc_bprop': [[2, 3]]}),
 ]
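Note: the ApplyMomentum and ScatterNdUpdate cases above were moved, not
dropped. The case-by-case forward pipelines materialize shape-style
'desc_inputs' as plain Tensors, so an op whose first input is declared
RW_WRITE can no longer compile in the forward suites; that same
construction is exactly what raise_set and test_cases_for_verify_exception
need in order to assert the new TypeError.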