Add prim name to error message for array_ops

pull/344/head
fary86 5 years ago
parent 789edcb23a
commit 2078229436

File diff suppressed because it is too large

@@ -175,7 +175,7 @@ AdjointPtr DFunctor::MapMorphism(const AnfNodePtr &morph) {
UpdateAdjoint(node_adjoint);
anfnode_to_adjoin_[morph] = node_adjoint;
if (cnode_morph->stop_gradient()) {
- MS_LOG(WARNING) << "MapMorphism node " << morph->ToString() << " is stopped.";
+ MS_LOG(DEBUG) << "MapMorphism node " << morph->ToString() << " is stopped.";
return node_adjoint;
}

@@ -19,7 +19,6 @@ from mindspore._checkparam import Validator as validator
from ... import context
from ..cell import Cell
from ..._checkparam import Rel
- from ..._checkparam import ParamValidator
class _PoolNd(Cell):
@@ -265,11 +264,11 @@ class AvgPool1d(_PoolNd):
stride=1,
pad_mode="valid"):
super(AvgPool1d, self).__init__(kernel_size, stride, pad_mode)
- ParamValidator.check_type('kernel_size', kernel_size, [int,])
- ParamValidator.check_type('stride', stride, [int,])
- self.pad_mode = ParamValidator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME'])
- ParamValidator.check_integer("kernel_size", kernel_size, 1, Rel.GE)
- ParamValidator.check_integer("stride", stride, 1, Rel.GE)
+ validator.check_value_type('kernel_size', kernel_size, [int], self.cls_name)
+ validator.check_value_type('stride', stride, [int], self.cls_name)
+ self.pad_mode = validator.check_string('pad_mode', pad_mode.upper(), ['VALID', 'SAME'], self.cls_name)
+ validator.check_integer("kernel_size", kernel_size, 1, Rel.GE, self.cls_name)
+ validator.check_integer("stride", stride, 1, Rel.GE, self.cls_name)
self.kernel_size = (1, kernel_size)
self.stride = (1, stride)
self.avg_pool = P.AvgPool(ksize=self.kernel_size,
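The hunk above migrates `AvgPool1d` from the old `ParamValidator` helpers to `Validator`, whose checks take the caller's name (`self.cls_name`) as a trailing argument so raised errors identify the operator. A minimal sketch of what such a check might look like (illustrative stand-in only, not the actual MindSpore implementation in `mindspore/_checkparam.py`; the exact message wording is assumed):

```python
# Illustrative stand-in for a Validator-style check that tags errors
# with the prim/cell name; message wording is an assumption.
def check_value_type(arg_name, arg_value, valid_types, prim_name=None):
    if not isinstance(arg_value, tuple(valid_types)):
        prefix = f"For '{prim_name}', " if prim_name else ""
        names = [t.__name__ for t in valid_types]
        raise TypeError(f"{prefix}the type of `{arg_name}` should be one of "
                        f"{names}, but got {type(arg_value).__name__}.")
    return arg_value
```

With this shape, `check_value_type('kernel_size', 2.5, [int], 'AvgPool1d')` would raise a `TypeError` whose message starts with "For 'AvgPool1d'", which is exactly what the tests below key on.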

File diff suppressed because it is too large

@@ -0,0 +1,159 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
""" test ops """
import functools
import numpy as np
from mindspore import ops
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G
import mindspore.ops.composite as C
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.common.parameter import Parameter
from ..ut_filter import non_graph_engine
from mindspore.common.api import _executor
from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward\
    import (pipeline_for_compile_forward_ge_graph_for_case_by_case_config,
            pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception)
from ....mindspore_test_framework.pipeline.gradient.compile_gradient\
    import pipeline_for_compile_grad_ge_graph_for_case_by_case_config


class ExpandDimsNet(nn.Cell):
    def __init__(self, axis):
        super(ExpandDimsNet, self).__init__()
        self.axis = axis
        self.op = P.ExpandDims()

    def construct(self, x):
        return self.op(x, self.axis)


class IsInstanceNet(nn.Cell):
    def __init__(self, inst):
        super(IsInstanceNet, self).__init__()
        self.inst = inst
        self.op = P.IsInstance()

    def construct(self, t):
        return self.op(self.inst, t)


class ReshapeNet(nn.Cell):
    def __init__(self, shape):
        super(ReshapeNet, self).__init__()
        self.shape = shape
        self.op = P.Reshape()

    def construct(self, x):
        return self.op(x, self.shape)


raise_set = [
    # input is scalar, not Tensor
    ('ExpandDims0', {
        'block': (P.ExpandDims(), {'exception': TypeError, 'error_keywords': ['ExpandDims']}),
        'desc_inputs': [5.0, 1],
        'skip': ['backward']}),
    # axis is passed as a parameter
    ('ExpandDims1', {
        'block': (P.ExpandDims(), {'exception': TypeError, 'error_keywords': ['ExpandDims']}),
        'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), 1],
        'skip': ['backward']}),
    # axis as an attribute, but less than the lower limit
    ('ExpandDims2', {
        'block': (ExpandDimsNet(-4), {'exception': ValueError, 'error_keywords': ['ExpandDims']}),
        'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32))],
        'skip': ['backward']}),
    # axis as an attribute, but greater than the upper limit
    ('ExpandDims3', {
        'block': (ExpandDimsNet(3), {'exception': ValueError, 'error_keywords': ['ExpandDims']}),
        'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32))],
        'skip': ['backward']}),
    # input is scalar, not Tensor
    ('DType0', {
        'block': (P.DType(), {'exception': TypeError, 'error_keywords': ['DType']}),
        'desc_inputs': [5.0],
        'skip': ['backward']}),
    # input x is scalar, not Tensor
    ('SameTypeShape0', {
        'block': (P.SameTypeShape(), {'exception': TypeError, 'error_keywords': ['SameTypeShape']}),
        'desc_inputs': [5.0, Tensor(np.ones([3, 4]).astype(np.float32))],
        'skip': ['backward']}),
    # input y is scalar, not Tensor
    ('SameTypeShape1', {
        'block': (P.SameTypeShape(), {'exception': TypeError, 'error_keywords': ['SameTypeShape']}),
        'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), 5.0],
        'skip': ['backward']}),
    # types of x and y do not match
    ('SameTypeShape2', {
        'block': (P.SameTypeShape(), {'exception': TypeError, 'error_keywords': ['SameTypeShape']}),
        'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 4]).astype(np.int32))],
        'skip': ['backward']}),
    # shapes of x and y do not match
    ('SameTypeShape3', {
        'block': (P.SameTypeShape(), {'exception': ValueError, 'error_keywords': ['SameTypeShape']}),
        'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), Tensor(np.ones([3, 3]).astype(np.float32))],
        'skip': ['backward']}),
    # sub_type is None
    ('IsSubClass0', {
        'block': (P.IsSubClass(), {'exception': TypeError, 'error_keywords': ['IsSubClass']}),
        'desc_inputs': [None, mstype.number],
        'skip': ['backward']}),
    # type_ is None
    ('IsSubClass1', {
        'block': (P.IsSubClass(), {'exception': TypeError, 'error_keywords': ['IsSubClass']}),
        'desc_inputs': [mstype.number, None],
        'skip': ['backward']}),
    # inst is a variable
    ('IsInstance0', {
        'block': (P.IsInstance(), {'exception': ValueError, 'error_keywords': ['IsInstance']}),
        'desc_inputs': [5.0, mstype.number],
        'skip': ['backward']}),
    # t is not an mstype.Type
    ('IsInstance1', {
        'block': (IsInstanceNet(5.0), {'exception': TypeError, 'error_keywords': ['IsInstance']}),
        'desc_inputs': [None],
        'skip': ['backward']}),
    # input x is scalar, not Tensor
    ('Reshape0', {
        'block': (P.Reshape(), {'exception': TypeError, 'error_keywords': ['Reshape']}),
        'desc_inputs': [5.0, (1, 2)],
        'skip': ['backward']}),
    # input shape is a variable
    ('Reshape1', {
        'block': (P.Reshape(), {'exception': TypeError, 'error_keywords': ['Reshape']}),
        'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32)), (2, 3, 2)],
        'skip': ['backward']}),
    # element of shape is not an int
    ('Reshape3', {
        'block': (ReshapeNet((2, 3.0, 2)), {'exception': TypeError, 'error_keywords': ['Reshape']}),
        'desc_inputs': [Tensor(np.ones([3, 4]).astype(np.float32))],
        'skip': ['backward']}),
]


@mindspore_test(pipeline_for_compile_forward_ge_graph_for_case_by_case_config_exception)
def test_check_exception():
    return raise_set
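Each `raise_set` entry pairs an operator (or a small `nn.Cell` wrapper) with the exception type and the keywords, including the prim name, that must appear in the error message. Roughly, the exception pipeline performs something like the following per case (a sketch of the assumed behavior; the real logic lives in `mindspore_test_framework` and compiles the block rather than calling it directly):

```python
# Sketch of the assumed per-case check done by the exception pipeline;
# names and control flow are simplified, not the framework's code.
import pytest

def run_exception_case(block, desc_inputs):
    op, expect = block
    with pytest.raises(expect['exception']) as ex:
        op(*desc_inputs)  # run the op with the deliberately bad inputs
    for keyword in expect['error_keywords']:
        assert keyword in str(ex.value)  # message must name the prim
```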

@@ -383,7 +383,7 @@ def test_tensor_slice_reduce_out_of_bounds_neg():
net = NetWork()
with pytest.raises(ValueError) as ex:
net(input_tensor)
assert "The `begin[0]` should be an int and must greater or equal to -6, but got -7" in str(ex.value)
assert "For 'StridedSlice' the `begin[0]` should be an int and must greater or equal to -6, but got `-7`" in str(ex.value)
def test_tensor_slice_reduce_out_of_bounds_positive():
@@ -400,4 +400,4 @@ def test_tensor_slice_reduce_out_of_bounds_positive():
net = NetWork()
with pytest.raises(ValueError) as ex:
net(input_tensor)
assert "The `begin[0]` should be an int and must less than 6, but got 6" in str(ex.value)
assert "For 'StridedSlice' the `begin[0]` should be an int and must less than 6, but got `6`" in str(ex.value)

@@ -16,7 +16,7 @@
import numpy as np
from mindspore._checkparam import Rel
- from mindspore._checkparam import ParamValidator as validator
+ from mindspore._checkparam import Validator as validator
def avg_pooling(x, pool_h, pool_w, stride):
@@ -32,7 +32,7 @@ def avg_pooling(x, pool_h, pool_w, stride):
Returns:
numpy.ndarray, an output array after applying average pooling on input array.
"""
validator.check_integer("stride", stride, 0, Rel.GT)
validator.check_integer("stride", stride, 0, Rel.GT, None)
num, channel, height, width = x.shape
out_h = (height - pool_h)//stride + 1
out_w = (width - pool_w)//stride + 1
@@ -217,7 +217,7 @@ def conv2d(x, weight, bias=None, stride=1, pad=0,
dilation=1, groups=1, padding_mode='zeros'):
"""Convolution 2D."""
# pylint: disable=unused-argument
- validator.check_type('stride', stride, (int, tuple))
+ validator.check_value_type('stride', stride, (int, tuple), None)
if isinstance(stride, int):
stride = (stride, stride)
elif len(stride) == 4:
@@ -229,7 +229,7 @@ def conv2d(x, weight, bias=None, stride=1, pad=0,
f"a tuple of two positive int numbers, but got {stride}")
stride_h = stride[0]
stride_w = stride[1]
- validator.check_type('dilation', dilation, (int, tuple))
+ validator.check_value_type('dilation', dilation, (int, tuple), None)
if isinstance(dilation, int):
dilation = (dilation, dilation)
elif len(dilation) == 4:
@@ -384,7 +384,7 @@ def matmul(x, w, b=None):
def max_pooling(x, pool_h, pool_w, stride):
"""Max pooling."""
validator.check_integer("stride", stride, 0, Rel.GT)
validator.check_integer("stride", stride, 0, Rel.GT, None)
num, channel, height, width = x.shape
out_h = (height - pool_h)//stride + 1
out_w = (width - pool_w)//stride + 1
@@ -427,7 +427,7 @@ def max_pool_grad_with_argmax(x, dout, arg_max, pool_h, pool_w, stride):
def max_pool_with_argmax(x, pool_h, pool_w, stride):
"""Max pooling with argmax."""
validator.check_integer("stride", stride, 0, Rel.GT)
validator.check_integer("stride", stride, 0, Rel.GT, None)
num, channel, height, width = x.shape
out_h = (height - pool_h)//stride + 1
out_w = (width - pool_w)//stride + 1
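These numpy-based vm_impl helpers have no prim to report, so `None` is passed for the trailing name argument, and the "For '<prim>'" prefix is presumably omitted. A simplified sketch of the assumed `check_integer` behavior (the real version takes a `Rel` enum and lives in `mindspore/_checkparam.py`; the string keys and message wording here are stand-ins):

```python
# Simplified stand-in for Validator.check_integer: when prim_name is
# None, the error carries no operator prefix.
import operator

_RELS = {'GT': (operator.gt, 'greater than'),
         'GE': (operator.ge, 'greater or equal to')}

def check_integer(arg_name, arg_value, value, rel, prim_name=None):
    cmp, phrase = _RELS[rel]
    if not (isinstance(arg_value, int) and cmp(arg_value, value)):
        prefix = f"For '{prim_name}', " if prim_name else ""
        raise ValueError(f"{prefix}`{arg_name}` should be an int and must be "
                         f"{phrase} {value}, but got `{arg_value}`.")
    return arg_value
```

For example, `check_integer("stride", 0, 0, 'GT', None)` raises a `ValueError` that names only the argument, while passing a prim name would prepend "For '<prim>', " as in the operator checks above.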
