implicit type conversion

Signed-off-by: candanzg <zhangshucheng@huawei.com>
pull/1257/head
candanzg 5 years ago
parent 4ce1cf4529
commit 2429da19fb

File diff suppressed because it is too large Load Diff

@ -321,8 +321,8 @@ def initializer(init, shape=None, dtype=mstype.float32):
dtype (:class:`mindspore.dtype`): The type of data in initialized tensor. Default: mindspore.float32.
Returns:
Union[Tensor, Initialized], When `init` is Tensor, the return is Tensor object,
otherwise the return is Initialize object.
Union[Tensor, Initializer], When `init` is Tensor, the return is Tensor object,
otherwise the return is Initialize object.
Examples:
>>> tensor = initializer('ones', [1, 2, 3], mindspore.float32)

@ -16,6 +16,7 @@
"""Parameter for cell."""
import numbers
from copy import copy, deepcopy
from . import dtype as mstype
from .initializer import initializer, Initializer
from .tensor import Tensor, MetaTensor
from .._checkparam import _check_str_by_regular
@ -199,6 +200,10 @@ class Parameter:
elif isinstance(data, Initializer):
self.init_mode = data
data = MetaTensor(self.init_mode.dtype, self.init_mode.shape)
elif isinstance(data, int):
data = Tensor(data, dtype=mstype.int32)
elif isinstance(data, float):
data = Tensor(data, dtype=mstype.float32)
else:
data = Tensor(data)
data.init_flag = False

@ -145,8 +145,8 @@ class AssignAdd(PrimitiveWithInfer):
>>> net(value)
"""
__mindspore_signature__ = (
('variable', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD),
('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD)
('variable', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
)
@prim_attr_register
@ -189,8 +189,8 @@ class AssignSub(PrimitiveWithInfer):
"""
__mindspore_signature__ = (
('variable', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD),
('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD)
('variable', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
)
@prim_attr_register

@ -24,6 +24,7 @@ import numpy as np
from ... import context
from ..._c_expression import signature_rw as sig_rw
from ..._c_expression import signature_kind as sig_kind
from ..._c_expression import signature_dtype as sig_dtype
from ..._checkparam import Validator as validator
from ..._checkparam import Rel
from ...common import dtype as mstype
@ -1489,11 +1490,13 @@ class ApplyMomentum(PrimitiveWithInfer):
Please refer to the usage in nn.ApplyMomentum.
"""
__mindspore_signature__ = (
('variable', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD),
('accumulation', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD),
('learning_rate', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD),
('gradient', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD),
('momentum', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD)
('variable', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('accumulation', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
sig_dtype.T),
('learning_rate', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE,
sig_dtype.T),
('gradient', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('momentum', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
)
@prim_attr_register
def __init__(self, use_nesterov=False, use_locking=False, gradient_scale=1.0):

@ -16,6 +16,7 @@
"""Other operators."""
from ..._c_expression import signature_rw as sig_rw
from ..._c_expression import signature_kind as sig_kind
from ..._c_expression import signature_dtype as sig_dtype
from ..._checkparam import Validator as validator, Rel
from ...common import dtype as mstype
from ..primitive import Primitive, PrimitiveWithInfer, prim_attr_register
@ -46,8 +47,8 @@ class Assign(PrimitiveWithInfer):
>>> net(x)
"""
__mindspore_signature__ = (
('variable', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD),
('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD)
('variable', sig_rw.RW_WRITE, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T),
('value', sig_rw.RW_READ, sig_kind.KIND_POSITIONAL_KEYWORD, sig_kind.KIND_EMPTY_DEFAULT_VALUE, sig_dtype.T)
)
@prim_attr_register
def __init__(self):

@ -0,0 +1,244 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""multitype_ops directory test case"""
import numpy as np
from functools import partial, reduce
import mindspore.nn as nn
from mindspore import Tensor
from mindspore import dtype as mstype
from mindspore.ops import functional as F, composite as C
import mindspore.context as context
import pytest
class TensorIntAutoCast(nn.Cell):
def __init__(self,):
super(TensorIntAutoCast, self).__init__()
self.i = 2
def construct(self, t):
z = F.tensor_mul(t, self.i)
return z
class TensorFPAutoCast(nn.Cell):
def __init__(self,):
super(TensorFPAutoCast, self).__init__()
self.f = 1.2
def construct(self, t):
z = F.tensor_mul(t, self.f)
return z
class TensorBoolAutoCast(nn.Cell):
def __init__(self,):
super(TensorBoolAutoCast, self).__init__()
self.f = True
def construct(self, t):
z = F.tensor_mul(t, self.f)
return z
class TensorAutoCast(nn.Cell):
def __init__(self,):
super(TensorAutoCast, self).__init__()
def construct(self, t1, t2):
z = F.tensor_mul(t1, t2)
return z
def test_tensor_auto_cast():
context.set_context(mode=context.GRAPH_MODE)
t0 = Tensor([True, False], mstype.bool_)
t_uint8 = Tensor(np.ones([2, 1, 2, 2]), mstype.uint8)
t_int8 = Tensor(np.ones([2, 1, 2, 2]), mstype.int8)
t_int16 = Tensor(np.ones([2, 1, 2, 2]), mstype.int16)
t_int32 = Tensor(np.ones([2, 1, 2, 2]), mstype.int32)
t_int64 = Tensor(np.ones([2, 1, 2, 2]), mstype.int64)
t_fp16 = Tensor(np.ones([2, 1, 2, 2]), mstype.float16)
t_fp32 = Tensor(np.ones([2, 1, 2, 2]), mstype.float32)
t_fp64 = Tensor(np.ones([2, 1, 2, 2]), mstype.float64)
net = TensorAutoCast()
rs = net(t_uint8, t_int8)
assert rs.dtype() == mstype.int16
rs = net(t_uint8, t_int16)
assert rs.dtype() == mstype.int16
rs = net(t_uint8, t_int32)
assert rs.dtype() == mstype.int32
rs = net(t_uint8, t_int64)
assert rs.dtype() == mstype.int64
rs = net(t_int8, t_int16)
assert rs.dtype() == mstype.int16
rs = net(t_int8, t_int32)
assert rs.dtype() == mstype.int32
rs = net(t_int8, t_int64)
assert rs.dtype() == mstype.int64
rs = net(t_int16, t_int32)
assert rs.dtype() == mstype.int32
rs = net(t_int16, t_int64)
assert rs.dtype() == mstype.int64
rs = net(t_int32, t_int64)
assert rs.dtype() == mstype.int64
rs = net(t_fp16, t_fp32)
assert rs.dtype() == mstype.float32
rs = net(t_fp16, t_fp64)
assert rs.dtype() == mstype.float64
rs = net(t_fp32, t_fp64)
assert rs.dtype() == mstype.float64
rs = net(t_uint8, t_fp16)
assert rs.dtype() == mstype.float16
rs = net(t_uint8, t_fp32)
assert rs.dtype() == mstype.float32
rs = net(t_uint8, t_fp64)
assert rs.dtype() == mstype.float64
rs = net(t_int8, t_fp64)
assert rs.dtype() == mstype.float64
rs = net(t_int16, t_fp64)
assert rs.dtype() == mstype.float64
rs = net(t_int32, t_fp64)
assert rs.dtype() == mstype.float64
rs = net(t_int64, t_fp64)
assert rs.dtype() == mstype.float64
rs = net(t_fp16, t_int8)
assert rs.dtype() == mstype.float16
rs = net(t_fp16, t_uint8)
assert rs.dtype() == mstype.float16
rs = net(t_fp16, t_int16)
assert rs.dtype() == mstype.float16
rs = net(t_fp16, t_int32)
assert rs.dtype() == mstype.float16
rs = net(t_fp16, t_int64)
assert rs.dtype() == mstype.float16
tint = TensorIntAutoCast()
rs = tint(t_uint8)
assert rs.dtype() == mstype.uint8
rs = tint(t_int8)
assert rs.dtype() == mstype.int8
rs = tint(t_int16)
assert rs.dtype() == mstype.int16
rs = tint(t_int32)
assert rs.dtype() == mstype.int32
rs = tint(t_int64)
assert rs.dtype() == mstype.int64
rs = tint(t_fp16)
assert rs.dtype() == mstype.float16
rs = tint(t_fp32)
assert rs.dtype() == mstype.float32
rs = tint(t_fp64)
assert rs.dtype() == mstype.float64
tfp = TensorFPAutoCast()
rs = tfp(t_uint8)
assert rs.dtype() == mstype.float32
rs = tfp(t_int8)
assert rs.dtype() == mstype.float32
rs = tfp(t_int16)
assert rs.dtype() == mstype.float32
rs = tfp(t_int32)
assert rs.dtype() == mstype.float32
rs = tfp(t_int64)
assert rs.dtype() == mstype.float32
rs = tfp(t_fp16)
assert rs.dtype() == mstype.float32
rs = tfp(t_fp32)
assert rs.dtype() == mstype.float32
rs = tfp(t_fp64)
assert rs.dtype() == mstype.float64
t_uint16 = Tensor(np.ones([2, 1, 2, 2]), mstype.uint16)
t_uint32 = Tensor(np.ones([2, 1, 2, 2]), mstype.uint32)
t_uint64 = Tensor(np.ones([2, 1, 2, 2]), mstype.uint64)
with pytest.raises(TypeError):
net(t_uint16, t_uint8)
with pytest.raises(TypeError):
net(t_uint16, t_int8)
with pytest.raises(TypeError):
net(t_uint16, t_int16)
with pytest.raises(TypeError):
net(t_uint16, t_int32)
with pytest.raises(TypeError):
net(t_uint16, t_int64)
with pytest.raises(TypeError):
net(t_uint32, t_uint8)
with pytest.raises(TypeError):
net(t_uint32, t_int8)
with pytest.raises(TypeError):
net(t_uint32, t_int16)
with pytest.raises(TypeError):
net(t_uint32, t_int32)
with pytest.raises(TypeError):
net(t_uint32, t_int64)
with pytest.raises(TypeError):
net(t_uint64, t_uint8)
with pytest.raises(TypeError):
net(t_uint64, t_int8)
with pytest.raises(TypeError):
net(t_uint64, t_int16)
with pytest.raises(TypeError):
net(t_uint64, t_int32)
with pytest.raises(TypeError):
net(t_uint64, t_int64)
with pytest.raises(TypeError):
net(t_uint16, t_fp16)
with pytest.raises(TypeError):
net(t_uint16, t_fp32)
with pytest.raises(TypeError):
net(t_uint16, t_fp64)
with pytest.raises(TypeError):
net(t_uint32, t_fp16)
with pytest.raises(TypeError):
net(t_uint32, t_fp32)
with pytest.raises(TypeError):
net(t_uint32, t_fp64)
with pytest.raises(TypeError):
net(t_uint64, t_fp16)
with pytest.raises(TypeError):
net(t_uint64, t_fp32)
with pytest.raises(TypeError):
net(t_uint64, t_fp64)
with pytest.raises(TypeError):
tfp(t_uint16)
with pytest.raises(TypeError):
tfp(t_uint32)
with pytest.raises(TypeError):
tfp(t_uint64)
with pytest.raises(TypeError):
tint(t_uint16)
with pytest.raises(TypeError):
tint(t_uint32)
with pytest.raises(TypeError):
tint(t_uint64)
bnet = TensorBoolAutoCast()
with pytest.raises(TypeError):
bnet(t_uint8)
with pytest.raises(TypeError):
bnet(t_int8)
with pytest.raises(TypeError):
bnet(t_int16)
with pytest.raises(TypeError):
bnet(t_int32)
with pytest.raises(TypeError):
bnet(t_int64)
with pytest.raises(TypeError):
bnet(t_fp16)
with pytest.raises(TypeError):
bnet(t_fp32)
with pytest.raises(TypeError):
bnet(t_fp64)

@ -64,7 +64,7 @@ def test_parameter_update_int32_and_tensor():
param_step = train_network.parameters_dict()['global_step']
update_global_step = ParameterUpdate(param_step)
input_step = Tensor(np.array([1000]), mstype.float32)
input_step = Tensor(np.array([1000]), mstype.int32)
_executor.compile(update_global_step, input_step)

@ -463,7 +463,7 @@ raise_set = [
'block': (lambda x: P.StridedSlice(new_axis_mask="1.1"), {'exception': TypeError}),
'desc_inputs': [0]}),
('AssignAdd_Error', {
'block': (P.AssignAdd(), {'exception': TypeError}),
'block': (P.AssignAdd(), {'exception': IndexError}),
'desc_inputs': [[1]]}),
]

File diff suppressed because it is too large Load Diff
Loading…
Cancel
Save