bugfix:
* fix bug in the output of tuple of tuple.
* check that a kRWWrite input is a variable; raise TypeError if not.
* input x of ScatterNdUpdate should be a parameter node.

pull/849/head
Wei Luning 5 years ago
parent 64c87170fd
commit 157710ca0f

@@ -14,7 +14,6 @@
 # ============================================================================
 """train_imagenet."""
 import os
-import math
 import argparse
 import random
 import numpy as np

@@ -195,6 +195,8 @@ AnfNodePtr BuildNewCNode(const FuncGraphPtr &func_graph, const std::string &func
       param = func_graph->NewCNode({NewValueNode(prim::kPrimGetRefKey), param});
     }
     // If sig is SignatureEnumRW::kRWRef, not do anything.
+  } else if (sig == SignatureEnumRW::kRWWrite && type->type_id() != kObjectTypeRefKey) {
+    MS_EXCEPTION(TypeError) << "Function " << func_name << "'s input " << i << " should be a Parameter.";
   }
   // add cast op here
   if (assign_source != nullptr && sig != SignatureEnumRW::kRWWrite) {
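With this check in place, an operator whose signature marks an input as SignatureEnumRW::kRWWrite rejects anything that does not resolve to a RefKey (i.e. a Parameter) at graph-build time, instead of failing later in the session. A minimal sketch of the behavior being enforced, written against the MindSpore API used elsewhere in this commit (the Acc cell name is made up for illustration):

import numpy as np
import mindspore.nn as nn
import mindspore.context as context
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P

class Acc(nn.Cell):
    def __init__(self, value):
        super(Acc, self).__init__()
        # A Parameter lowers to a RefKey node, which satisfies kRWWrite.
        self.var = Parameter(value, name="var")
        self.add = P.AssignAdd()

    def construct(self, y):
        return self.add(self.var, y)

context.set_context(mode=context.GRAPH_MODE)
acc = Acc(Tensor(np.zeros((2,), np.float32)))
acc(Tensor(np.ones((2,), np.float32)))  # ok: writes into the Parameter
# Feeding a plain Tensor as the written-to input instead would now raise
# "TypeError: Function ...'s input 0 should be a Parameter."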

@@ -70,12 +70,11 @@ def _wrap_func(fn):
         def _convert_data(data):
             if isinstance(data, Tensor) and not isinstance(data, MsTensor):
                 return MsTensor(data)
+            if isinstance(data, tuple):
+                return tuple(_convert_data(x) for x in data)
+            if isinstance(data, list):
+                return list(_convert_data(x) for x in data)
             return data
-        if isinstance(results, tuple):
-            return tuple(_convert_data(x) for x in results)
-        if isinstance(results, list):
-            return list(_convert_data(x) for x in results)
         return _convert_data(results)
     return wrapper
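Moving the tuple/list handling inside _convert_data is what fixes the "output tuple of tuple" case from the commit message: the old top-level dispatch converted only one container level, so elements of a nested tuple were returned unwrapped. A self-contained sketch of the recursion, with stand-in classes (Tensor/MsTensor here are placeholders, not the real types):

class Tensor:                 # stand-in for the backend tensor type
    pass

class MsTensor(Tensor):       # stand-in for the user-facing tensor type
    def __init__(self, data=None):
        self.data = data

def _convert_data(data):
    # Wrap any backend tensor that is not already user-facing.
    if isinstance(data, Tensor) and not isinstance(data, MsTensor):
        return MsTensor(data)
    # Recurse so every nesting level gets converted, not just the outermost.
    if isinstance(data, tuple):
        return tuple(_convert_data(x) for x in data)
    if isinstance(data, list):
        return list(_convert_data(x) for x in data)
    return data

out = _convert_data(((Tensor(), Tensor()), [Tensor()]))
assert isinstance(out[0][0], MsTensor) and isinstance(out[1][0], MsTensor)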

@@ -57,21 +57,22 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, param, m, v, grad
     op_reshape = P.Reshape()
     op_shape = P.Shape()
-    param = op_cast(param, mstype.float32)
-    m = op_cast(m, mstype.float32)
-    v = op_cast(v, mstype.float32)
-    gradient = op_cast(gradient, mstype.float32)
-    next_m = op_mul(beta1, m) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient)
-    next_v = op_mul(beta2, v) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta2, op_square(gradient))
+    param_fp32 = op_cast(param, mstype.float32)
+    m_fp32 = op_cast(m, mstype.float32)
+    v_fp32 = op_cast(v, mstype.float32)
+    gradient_fp32 = op_cast(gradient, mstype.float32)
+    next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32) - beta1, gradient_fp32)
+    next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(F.tuple_to_array((1.0,)), mstype.float32)
+                                            - beta2, op_square(gradient_fp32))
     update = next_m / (op_sqrt(next_v) + eps)
     if decay_flag:
-        update = update + op_mul(weight_decay_tensor, param)
+        update = update + op_mul(weight_decay_tensor, param_fp32)
     update_with_lr = op_mul(lr, update)
-    next_param = param - op_reshape(update_with_lr, op_shape(param))
+    next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
     next_v = F.depend(next_v, F.assign(param, next_param))
     next_v = F.depend(next_v, F.assign(m, next_m))
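The *_fp32 renames are load-bearing, not cosmetic. The old code rebound `param` to the output of `op_cast`, so the final `F.assign(param, next_param)` received a cast temporary rather than the Parameter, which is exactly the situation the new kRWWrite check turns into a TypeError. A plain-Python analogy of the rebinding pitfall (lists stand in for mutable Parameters, copying stands in for the cast):

def update_buggy(param):
    param = list(param)                 # the "cast" rebinds the local name
    param = [x - 0.5 for x in param]    # all math happens on the copy
    return param                        # the caller's object was never written

def update_fixed(param):
    param_fp32 = list(param)            # keep the copy under its own name
    param[:] = [x - 0.5 for x in param_fp32]  # write back through the original handle
    return param

weights = [1.0, 2.0]
update_buggy(weights)
assert weights == [1.0, 2.0]            # unchanged: the update hit a temporary
update_fixed(weights)
assert weights == [0.5, 1.5]            # updated in place, as F.assign intends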

@@ -67,23 +67,23 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para
     op_fill = P.Fill()
     op_dtype = P.DType()
-    param = op_cast(param, mstype.float32)
-    m = op_cast(m, mstype.float32)
-    v = op_cast(v, mstype.float32)
-    gradient = op_cast(gradient, mstype.float32)
-    next_m = op_mul(beta1, m) + op_mul(op_cast(num_one, mstype.float32) - beta1, gradient)
-    next_v = op_mul(beta2, v) + op_mul(op_cast(num_one, mstype.float32) - beta2, op_square(gradient))
+    param_fp32 = op_cast(param, mstype.float32)
+    m_fp32 = op_cast(m, mstype.float32)
+    v_fp32 = op_cast(v, mstype.float32)
+    gradient_fp32 = op_cast(gradient, mstype.float32)
+    next_m = op_mul(beta1, m_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta1, gradient_fp32)
+    next_v = op_mul(beta2, v_fp32) + op_mul(op_cast(num_one, mstype.float32) - beta2, op_square(gradient_fp32))
     next_mm = next_m / (op_cast(num_one, mstype.float32)
                         - op_pow(beta1, op_cast(global_step + num_one, mstype.float32)))
     next_vv = next_v / (op_cast(num_one, mstype.float32) -
                         op_pow(beta2, op_cast(global_step + num_one, mstype.float32)))
-    w_norm = op_norm(param)
-    g_norm = op_norm(gradient)
-    g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt(next_vv + eps)) + weight_decay_tensor * param)
+    w_norm = op_norm(param_fp32)
+    g_norm = op_norm(gradient_fp32)
+    g_norm_hat = op_norm(op_mul(next_mm, op_rsqrt(next_vv + eps)) + weight_decay_tensor * param_fp32)
     zeros = F.zeros_like_tensor(w_norm)
     ones = op_fill(op_dtype(w_norm), op_shape(w_norm), 1.0)
     trust_ratio = op_select(

@@ -95,11 +95,11 @@ def _update_run_op(beta1, beta2, eps, lr, weight_decay_tensor, global_step, para
     update = next_mm / (op_sqrt(next_vv) + eps)
     if decay_flag:
-        update = update + op_mul(weight_decay_tensor, param)
+        update = update + op_mul(weight_decay_tensor, param_fp32)
     update_with_lr = op_mul(op_mul(trust_ratio, lr), update)
-    next_param = param - op_reshape(update_with_lr, op_shape(param))
+    next_param = param_fp32 - op_reshape(update_with_lr, op_shape(param_fp32))
     next_v = F.depend(next_v, F.assign(param, next_param))
     next_v = F.depend(next_v, F.assign(m, next_m))

@@ -24,6 +24,8 @@ import itertools
 import numbers
 import numpy as np
+from ..._c_expression import signature_rw as sig_rw
+from ..._c_expression import signature_kind as sig_kind
 from ..._checkparam import Validator as validator
 from ..._checkparam import Rel
 from ...common import dtype as mstype

@@ -1965,6 +1967,11 @@ class ScatterNdUpdate(PrimitiveWithInfer):
         >>> op = P.ScatterNdUpdate()
         >>> output = op(input_x, indices, update)
     """
+    __mindspore_signature__ = (
+        ('input_x', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD),
+        ('indices', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD),
+        ('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD)
+    )
+
     @prim_attr_register
     def __init__(self, use_locking=True):
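Declaring input_x as RW_WRITE makes ScatterNdUpdate behave like the other in-place ops: its first input must be a Parameter, and the raise_set entry added at the end of this commit asserts that a plain Tensor is now a TypeError. A minimal usage sketch (ScatterNet is an illustrative name; shapes are taken from the relocated test case):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P

class ScatterNet(nn.Cell):
    def __init__(self, value):
        super(ScatterNet, self).__init__()
        # input_x is RW_WRITE, so it must live in a Parameter.
        self.input_x = Parameter(value, name="input_x")
        self.scatter = P.ScatterNdUpdate()

    def construct(self, indices, update):
        return self.scatter(self.input_x, indices, update)

net = ScatterNet(Tensor(np.ones((2, 3), np.float32)))
# Writes 5.0 at [0, 0] and 6.0 at [1, 1] into the stored Parameter.
out = net(Tensor(np.array([[0, 0], [1, 1]], np.int32)),
          Tensor(np.array([5.0, 6.0], np.float32)))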

@@ -14,19 +14,20 @@
 # ============================================================================
 import pytest
-from mindspore import Tensor
+from mindspore import Tensor, Parameter
 from mindspore.ops import operations as P
 import mindspore.nn as nn
 import numpy as np
 import mindspore.context as context

 class AssignAdd(nn.Cell):
-    def __init__( self):
+    def __init__(self, value):
         super(AssignAdd, self).__init__()
+        self.var = Parameter(value, name="var")
         self.add = P.AssignAdd()

-    def construct(self, x, y):
-        res = self.add(x, y)
+    def construct(self, y):
+        res = self.add(self.var, y)
         return res

 @pytest.mark.level0

@@ -58,15 +59,17 @@ def test_assign_add():
     y2 = Tensor(np.arange(1 * 3 * 3 * 3).reshape(1, 3, 3, 3).astype(np.float32))
     context.set_context(mode=context.PYNATIVE_MODE, device_target='GPU')
-    add = AssignAdd()
-    output1 = add(x1, y1)
+    add = AssignAdd(x1)
+    output1 = add(y1)
     assert (output1.asnumpy() == expect1).all()
-    output2 = add(output1, y1)
+    add = AssignAdd(output1)
+    output2 = add(y1)
     assert (output2.asnumpy() == expect2).all()
     context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
-    add = AssignAdd()
-    output1 = add(x2, y2)
+    add = AssignAdd(x2)
+    output1 = add(y2)
     assert (output1.asnumpy() == expect1).all()
-    output2 = add(output1, y2)
+    add = AssignAdd(output1)
+    output2 = add(y2)
     assert (output2.asnumpy() == expect2).all()

@@ -14,7 +14,7 @@
 # ============================================================================
 import pytest
-from mindspore import Tensor
+from mindspore import Tensor, Parameter
 from mindspore.ops import operations as P
 import mindspore.nn as nn
 import numpy as np

@@ -22,12 +22,13 @@ import mindspore.context as context
 class Net(nn.Cell):
-    def __init__(self):
+    def __init__(self, value):
         super(Net, self).__init__()
+        self.var = Parameter(value, name="var")
         self.assign = P.Assign()

-    def construct(self, var, value):
-        return self.assign(var, value)
+    def construct(self, value):
+        return self.assign(self.var, value)

 x = np.array([[1.2, 1], [1, 0]]).astype(np.float32)
 value = np.array([[1, 2], [3, 4.0]]).astype(np.float32)

@@ -37,13 +38,13 @@ value = np.array([[1, 2], [3, 4.0]]).astype(np.float32)
 @pytest.mark.env_onecard
 def test_assign():
     context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
-    assign = Net()
     var = Tensor(x)
-    output = assign(var, Tensor(value))
+    assign = Net(var)
+    output = assign(Tensor(value))
     error = np.ones(shape=[2, 2]) * 1.0e-6
     diff1 = output.asnumpy() - value
-    diff2 = var.asnumpy() - value
+    diff2 = assign.var.default_input.asnumpy() - value
     assert np.all(diff1 < error)
     assert np.all(-diff1 < error)
     assert np.all(diff2 < error)

@@ -341,6 +341,15 @@ class SignNet(nn.Cell):
     def construct(self, x):
         return self.sign(x)

+class AssignAdd(nn.Cell):
+    def __init__(self):
+        super().__init__()
+        self.op = P.AssignAdd()
+        self.inputdata = Parameter(initializer(1, [1], ms.float32), name="global_step")
+
+    def construct(self, input_):
+        self.inputdata = input_
+        return self.op(self.inputdata, input_)

 test_case_math_ops = [
     ('MatMulGrad', {

@@ -413,6 +422,9 @@ raise_set = [
     ('StridedSlice_4_Error', {
         'block': (lambda x: P.StridedSlice(new_axis_mask="1.1"), {'exception': TypeError}),
         'desc_inputs': [0]}),
+    ('AssignAdd_Error', {
+        'block': (P.AssignAdd(), {'exception': TypeError}),
+        'desc_inputs': [[1]]}),
 ]

@@ -38,8 +38,7 @@ def tensor_run_opt(opt, iters, learning_rate, momentum,
                    gradient, variable, moment):
     """ tensor_run_opt """
     success = True
-    new_weight = opt(gradient, moment, variable,
-                     learning_rate, momentum)
+    new_weight = opt(variable, moment, learning_rate, gradient, momentum)
     success = F.depend(success, F.assign(variable, new_weight))
     return success
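The reordered call matches ApplyMomentum's input order, in which the written-to `variable` comes first and its accumulator second, so the kRWWrite machinery can check both: opt(variable, accumulation, learning_rate, gradient, momentum). A hedged sketch of a cell that feeds Parameters for both writable inputs (MomentumStep is an illustrative name, not from this commit):

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.ops import operations as P

class MomentumStep(nn.Cell):
    def __init__(self, weight_init):
        super(MomentumStep, self).__init__()
        # Both the weight and its accumulator are written in place,
        # so both must be Parameters.
        self.variable = Parameter(weight_init, name="variable")
        self.accumulation = Parameter(Tensor(np.zeros((2,), np.float32)),
                                      name="accumulation")
        self.opt = P.ApplyMomentum()

    def construct(self, learning_rate, gradient, momentum):
        return self.opt(self.variable, self.accumulation,
                        learning_rate, gradient, momentum)

step = MomentumStep(Tensor(np.ones((2,), np.float32)))
step(Tensor(np.array(0.01, np.float32)),
     Tensor(np.ones((2,), np.float32)),
     Tensor(np.array(0.9, np.float32)))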

@@ -446,12 +446,6 @@ test_cases = [
         'desc_inputs': [[128, 32, 32, 64]],
         'desc_bprop': [[128, 32, 32, 64]],
     }),
-    ('ApplyMomentum', {
-        'block': P.ApplyMomentum(),
-        'desc_inputs': [[2], [128, 32, 32, 64], [128, 32, 32, 64], [128, 32, 32, 64], [128, 32, 32, 64]],
-        'desc_bprop': [[128, 32, 32, 64]],
-        'skip': ['backward']
-    }),
     ('ScalarSummary', {
         'block': ScalarSummaryNet(),
         'desc_inputs': [2.2],

@@ -515,6 +509,12 @@ test_cases = [
 ]

 test_cases_for_verify_exception = [
+    ('ApplyMomentum_Error', {
+        'block': (P.ApplyMomentum(), {'exception': TypeError}),
+        'desc_inputs': [[2], [128, 32, 32, 64], [128, 32, 32, 64], [128, 32, 32, 64], [128, 32, 32, 64]],
+        'desc_bprop': [[128, 32, 32, 64]],
+        'skip': ['backward']
+    }),
     ('Conv2d_ValueError_1', {
         'block': (lambda _: P.Conv2D(3, 4, mode=-2.0), {'exception': TypeError}),
         'desc_inputs': [0],

@@ -674,12 +674,6 @@ test_case_nn_ops = [
         'desc_inputs': [[128, 64, 32, 32], [128, 64, 32, 32], [64], [64], [64], [64]],
         'desc_bprop': [[128, 64, 32, 32], [64], [64], [64], [64]],
         'skip': ['backward']}),
-    ('ApplyMomentum', {
-        'block': P.ApplyMomentum(),
-        'desc_inputs': [[128, 32, 32, 64], [128, 32, 32, 64],
-                        [32, 32, 64], [32, 32, 64], [32, 32, 64]],
-        'desc_bprop': [[128, 32, 32, 64]],
-        'skip': ['backward']}),
     ('TopK', {
         'block': P.TopK(),
         'desc_const': [5],

@@ -1113,12 +1107,6 @@ test_case_other_ops = [
         'desc_inputs': (Tensor(np.ones((1, 3, 6, 6), np.float32)),
                         Tensor(np.ones((2, 4), np.int32))),
         'desc_bprop': [[2]]}),
-    ('ScatterNdUpdate', {
-        'block': P.ScatterNdUpdate(),
-        'desc_inputs': (Tensor(np.ones((2, 3), np.float32)),
-                        Tensor(np.ones((2, 2), np.int32)),
-                        Tensor(np.ones((2,), np.float32))),
-        'desc_bprop': [[2, 3]]}),
     ('ScatterNd', {
         'block': P.ScatterNd(),
         'desc_const': [(3, 3)],
@@ -1178,7 +1166,7 @@ import mindspore.context as context
 @non_graph_engine
 @mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config)
 def test_exec():
-    context.set_context(mode=context.GRAPH_MODE)
+    context.set_context(mode=context.GRAPH_MODE, save_graphs=True)
     return test_exec_case

@@ -1207,6 +1195,12 @@ raise_set = [
         'block': (NetForFlatten0D(), {'exception': ValueError}),
         'desc_inputs': [Tensor(np.array(0).astype(np.int32))],
         'desc_bprop': [Tensor(np.array(0).astype(np.int32))]}),
+    ('ScatterNdUpdate', {
+        'block': (P.ScatterNdUpdate(), {'exception': TypeError}),
+        'desc_inputs': (Tensor(np.ones((2, 3), np.float32)),
+                        Tensor(np.ones((2, 2), np.int32)),
+                        Tensor(np.ones((2,), np.float32))),
+        'desc_bprop': [[2, 3]]}),
 ]
