!248 Add 5 random ops

Merge pull request !248 from peixu_ren/custom_aicpu
pull/3198/head
mindspore-ci-bot authored 5 years ago, committed by Gitee
commit 99d4289251

@@ -26,7 +26,6 @@ from .squeeze import _squeeze_aicpu
from .expand_dims import _expand_dims_aicpu
from .random_choice_with_mask import _random_choice_with_mask_aicpu
from .pack import _pack_aicpu
from .normal import _normal_aicpu
from .ctcloss import _ctcloss_aicpu
from .reverse_sequence import _reverse_sequence_aicpu
from .crop_and_resize import _crop_and_resize_aicpu
@@ -34,3 +33,8 @@ from .rnnt_loss import _rnnt_loss_aicpu
from .random_categorical import _random_categorical_aicpu
from .cast import _cast_aicpu
from .mirror_pad import _mirror_pad_aicpu
from .normal import _normal_aicpu
from .gamma import _gamma_aicpu
from .poisson import _poisson_aicpu
from .uniform_int import _uniform_int_aicpu
from .uniform_real import _uniform_real_aicpu
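
Editor's note, not part of this merge request: these are side-effect imports — @op_info_register runs when each module is imported, so adding the new modules to this list is what makes the kernels visible to the Ascend backend. A minimal sketch, assuming the package path these files live under:

# Editor's sketch: importing the package is enough, because every module's
# @op_info_register decorator executes at import time and records the op info.
import mindspore.ops._op_impl.aicpu  # noqa: F401  (side-effect import registers Gamma, Poisson, ...)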

@@ -0,0 +1,33 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""RandomGamma op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

gamma_op_info = AiCPURegOp("Gamma") \
    .fusion_type("OPAQUE") \
    .input(0, "shape", "required") \
    .input(1, "alpha", "required") \
    .input(2, "beta", "required") \
    .output(0, "output", "required") \
    .attr("seed", "int") \
    .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.I32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \
    .get_op_info()

@op_info_register(gamma_op_info)
def _gamma_aicpu():
    """RandomGamma AiCPU register"""
    return
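
How to read the registration above (editor's note, not part of this merge request): each .dtype_format(...) row lists one supported dtype/format combination, with the positional arguments following the declaration order of the inputs and the output; the second row registers the same dtypes for NCHW-formatted tensors.

# Editor's annotation of the first dtype_format row of the Gamma registration above:
#   input 0  "shape"  -> DataType.I32_Default  (int32 tensor giving the sample shape)
#   input 1  "alpha"  -> DataType.F32_Default  (float32)
#   input 2  "beta"   -> DataType.F32_Default  (float32)
#   output 0 "output" -> DataType.F32_Default  (float32 samples)
# On the Python side this kernel backs the P.Gamma primitive; see test_gamma_op.py below.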

@@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""Normal op"""
"""RandomNormal op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

normal_op_info = AiCPURegOp("Normal") \
@@ -21,7 +21,7 @@ normal_op_info = AiCPURegOp("Normal") \
    .input(0, "shape", "required") \
    .input(1, "mean", "required") \
    .input(2, "stddev", "required") \
    .output(0, "y", "required") \
    .output(0, "output", "required") \
    .attr("seed", "int") \
    .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.I32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \
@@ -29,5 +29,5 @@ normal_op_info = AiCPURegOp("Normal") \
@op_info_register(normal_op_info)
def _normal_aicpu():
    """Normal AiCPU register"""
    """RandomNormal AiCPU register"""
    return

@@ -0,0 +1,32 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""RandomPoisson op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

poisson_op_info = AiCPURegOp("Poisson") \
    .fusion_type("OPAQUE") \
    .input(0, "shape", "required") \
    .input(1, "mean", "required") \
    .output(0, "output", "required") \
    .attr("seed", "int") \
    .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I32_NCHW, DataType.F32_NCHW, DataType.I32_NCHW) \
    .get_op_info()

@op_info_register(poisson_op_info)
def _poisson_aicpu():
    """RandomPoisson AiCPU register"""
    return

@@ -0,0 +1,33 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""RandomUniformInt op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

uniform_int_op_info = AiCPURegOp("UniformInt") \
    .fusion_type("OPAQUE") \
    .input(0, "shape", "required") \
    .input(1, "a", "required") \
    .input(2, "b", "required") \
    .output(0, "output", "required") \
    .attr("seed", "int") \
    .dtype_format(DataType.I32_Default, DataType.I32_Default, DataType.I32_Default, DataType.I32_Default) \
    .dtype_format(DataType.I32_NCHW, DataType.I32_NCHW, DataType.I32_NCHW, DataType.I32_NCHW) \
    .get_op_info()

@op_info_register(uniform_int_op_info)
def _uniform_int_aicpu():
    """RandomUniformInt AiCPU register"""
    return

@@ -0,0 +1,33 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""RandomUniformReal op"""
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType

uniform_real_op_info = AiCPURegOp("UniformReal") \
    .fusion_type("OPAQUE") \
    .input(0, "shape", "required") \
    .input(1, "a", "required") \
    .input(2, "b", "required") \
    .output(0, "output", "required") \
    .attr("seed", "int") \
    .dtype_format(DataType.I32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.I32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW, DataType.F32_NCHW) \
    .get_op_info()

@op_info_register(uniform_real_op_info)
def _uniform_real_aicpu():
    """RandomUniformReal AiCPU register"""
    return
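
Taken together (editor's summary, not part of this merge request), the five AiCPU registrations differ only in their parameter lists and output dtypes; each takes an int32 shape input and an integer seed attribute:

# Editor's summary of the registrations above:
#   Normal      (shape, mean: f32, stddev: f32) -> f32
#   Gamma       (shape, alpha: f32, beta: f32)  -> f32
#   Poisson     (shape, mean: f32)              -> i32
#   UniformInt  (shape, a: i32, b: i32)         -> i32
#   UniformReal (shape, a: f32, b: f32)         -> f32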

@@ -54,7 +54,8 @@ from .math_ops import (Abs, ACos, Asin, Asinh, AddN, AccumulateNV2, AssignAdd, A
                       Sin, Sqrt, Rsqrt, BesselI0e, BesselI1e,
                       Square, Sub, TensorAdd, Sign, Round, SquareSumAll, Atan, Atanh, Cosh, Sinh, Eps)
from .random_ops import (RandomChoiceWithMask, Normal, RandomCategorical)
from .random_ops import (RandomChoiceWithMask, Normal, Gamma, Poisson, UniformInt, UniformReal,
                         RandomCategorical)
from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, ApplyMomentum, BatchNorm,
                     BiasAdd, Conv2D,
                     DepthwiseConv2dNative,
@@ -172,6 +173,10 @@ __all__ = [
    'Tanh',
    'RandomChoiceWithMask',
    'Normal',
    'Gamma',
    'Poisson',
    'UniformInt',
    'UniformReal',
    'RandomCategorical',
    'ResizeBilinear',
    'ScalarSummary',
(File diff suppressed because it is too large.)
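
For orientation (an editor's sketch, not part of this merge request): the five primitives added to the export list above share one calling pattern — construct the primitive with a seed, then pass the target shape and the distribution parameters as Tensors from inside a Cell. An Ascend device is assumed; the test files below do exactly this, one op per file.

import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

class RandomSample(nn.Cell):
    """Editor's sketch: draw from three of the new distributions in one Cell."""
    def __init__(self, shape, seed=0):
        super(RandomSample, self).__init__()
        self.shape = shape
        self.normal = P.Normal(seed=seed)
        self.gamma = P.Gamma(seed=seed)
        self.uniform_int = P.UniformInt(seed=seed)

    def construct(self, mean, stddev, alpha, beta, low, high):
        n = self.normal(self.shape, mean, stddev)    # float32 samples
        g = self.gamma(self.shape, alpha, beta)      # float32 samples
        u = self.uniform_int(self.shape, low, high)  # int32 samples
        return n, g, u

net = RandomSample((3, 2, 4), seed=10)
n, g, u = net(Tensor(0.0, mstype.float32), Tensor(1.0, mstype.float32),
              Tensor(1.0, mstype.float32), Tensor(1.0, mstype.float32),
              Tensor(1, mstype.int32), Tensor(5, mstype.int32))
# Poisson and UniformReal follow the same pattern; see their tests below.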

@@ -0,0 +1,58 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

class Net(nn.Cell):
    def __init__(self, shape, seed=0):
        super(Net, self).__init__()
        self.gamma = P.Gamma(seed=seed)
        self.shape = shape

    def construct(self, alpha, beta):
        return self.gamma(self.shape, alpha, beta)

def test_net_1D():
    seed = 10
    shape = (3, 2, 4)
    alpha = 1.0
    beta = 1.0
    net = Net(shape, seed)
    talpha, tbeta = Tensor(alpha, mstype.float32), Tensor(beta, mstype.float32)
    output = net(talpha, tbeta)
    print(output.asnumpy())
    assert output.shape == (3, 2, 4)

def test_net_ND():
    seed = 10
    shape = (3, 1, 2)
    alpha = np.array([[[1], [2]], [[3], [4]], [[5], [6]]]).astype(np.float32)
    beta = np.array([1.0]).astype(np.float32)
    net = Net(shape, seed)
    talpha, tbeta = Tensor(alpha), Tensor(beta)
    output = net(talpha, tbeta)
    print(output.asnumpy())
    assert output.shape == (3, 2, 2)
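
The (3, 2, 2) expected by test_net_ND is just the broadcast of the requested shape against the parameter shapes; a quick check (editor's sketch, assuming numpy-style broadcasting, which is what the assertion above implies — the ND cases in the Normal, UniformInt, and UniformReal tests below work the same way):

import numpy as np

# shape (3, 1, 2) broadcast with alpha (3, 2, 1) and beta (1,) gives (3, 2, 2)
assert np.broadcast(np.zeros((3, 1, 2)), np.zeros((3, 2, 1)), np.zeros((1,))).shape == (3, 2, 2)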

@@ -12,32 +12,46 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.common import Tensor
from mindspore.common import dtype as mstype

context.set_context(mode=context.PYNATIVE_MODE, device_target="Ascend")
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

class Net(nn.Cell):
    def __init__(self, shape=None, mean=0.0, stddev=1.0, seed=0):
    def __init__(self, shape, seed=0):
        super(Net, self).__init__()
        self._mean = Tensor(mean, mstype.float32)
        self._stddev = Tensor(stddev, mstype.float32)
        self._normal = P.Normal(seed=seed)
        self._shape = shape
        self.normal = P.Normal(seed=seed)
        self.shape = shape

    def construct(self):
        return self._normal(self._shape, self._mean, self._stddev)
    def construct(self, mean, stddev):
        return self.normal(self.shape, mean, stddev)

def test_net_3x2x4():
    mean = 0.0
def test_net_1D():
    seed = 10
    shape = (3, 2, 4)
    mean = 1.0
    stddev = 1.0
    seed = 0
    net = Net((3, 2, 4), mean, stddev, seed)
    out = net()
    assert out.shape == (3, 2, 4)
    net = Net(shape, seed)
    tmean, tstddev = Tensor(mean, mstype.float32), Tensor(stddev, mstype.float32)
    output = net(tmean, tstddev)
    print(output.asnumpy())
    assert output.shape == (3, 2, 4)

def test_net_ND():
    seed = 10
    shape = (3, 1, 2)
    mean = np.array([[[1], [2]], [[3], [4]], [[5], [6]]]).astype(np.float32)
    stddev = np.array([1.0]).astype(np.float32)
    net = Net(shape, seed)
    tmean, tstddev = Tensor(mean), Tensor(stddev)
    output = net(tmean, tstddev)
    print(output.asnumpy())
    assert output.shape == (3, 2, 2)

@@ -0,0 +1,53 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

class Net(nn.Cell):
    def __init__(self, shape):
        super(Net, self).__init__()
        self.poisson = P.Poisson()
        self.shape = shape

    def construct(self, mean):
        return self.poisson(self.shape, mean)

def test_net_1():
    shape = (2, 16)
    mean = np.array([5.0]).astype(np.float32)
    net = Net(shape)
    tmean = Tensor(mean)
    output = net(tmean)
    print(output.asnumpy())
    assert output.shape == (2, 16)

def test_net_2():
    shape = (4, 1)
    mean = np.array([5.0, 10.0]).astype(np.float32)
    net = Net(shape)
    tmean = Tensor(mean)
    output = net(tmean)
    print(output.asnumpy())
    assert output.shape == (4, 2)

@@ -0,0 +1,57 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

class Net(nn.Cell):
    def __init__(self, shape, seed=0):
        super(Net, self).__init__()
        self.uniformint = P.UniformInt(seed=seed)
        self.shape = shape

    def construct(self, a, b):
        return self.uniformint(self.shape, a, b)

def test_net_1D():
    seed = 10
    shape = (3, 2, 4)
    a = 1
    b = 5
    net = Net(shape, seed)
    ta, tb = Tensor(a, mstype.int32), Tensor(b, mstype.int32)
    output = net(ta, tb)
    print(output.asnumpy())
    assert output.shape == (3, 2, 4)

def test_net_ND():
    seed = 10
    shape = (3, 2, 1)
    a = np.array([[[1, 2]], [[3, 4]], [[5, 6]]]).astype(np.int32)
    b = np.array([10]).astype(np.int32)
    net = Net(shape, seed)
    ta, tb = Tensor(a), Tensor(b)
    output = net(ta, tb)
    print(output.asnumpy())
    assert output.shape == (3, 2, 2)

@@ -0,0 +1,57 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import numpy as np
import mindspore.context as context
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.common import dtype as mstype

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend")

class Net(nn.Cell):
    def __init__(self, shape, seed=0):
        super(Net, self).__init__()
        self.uniformreal = P.UniformReal(seed=seed)
        self.shape = shape

    def construct(self, a, b):
        return self.uniformreal(self.shape, a, b)

def test_net_1D():
    seed = 10
    shape = (3, 2, 4)
    a = 1.0
    b = 5.0
    net = Net(shape, seed)
    ta, tb = Tensor(a, mstype.float32), Tensor(b, mstype.float32)
    output = net(ta, tb)
    print(output.asnumpy())
    assert output.shape == (3, 2, 4)

def test_net_ND():
    seed = 10
    shape = (3, 2, 1)
    a = np.array([[[1, 2]], [[3, 4]], [[5, 6]]]).astype(np.float32)
    b = np.array([10]).astype(np.float32)
    net = Net(shape, seed)
    ta, tb = Tensor(a), Tensor(b)
    output = net(ta, tb)
    print(output.asnumpy())
    assert output.shape == (3, 2, 2)

@@ -400,15 +400,57 @@ class InplaceSubNet(nn.Cell):
class NormalNet(nn.Cell):
    def __init__(self, shape=None, mean=0.0, stddev=1.0, seed=0):
    def __init__(self, shape=None, seed=0):
        super(NormalNet, self).__init__()
        self.normal = P.Normal(seed=seed)
        self.shape = shape
        self.mean = Tensor(mean, mstype.float32)
        self.stddev = Tensor(stddev, mstype.float32)

    def construct(self):
        out = self.normal(self.shape, self.mean, self.stddev)
    def construct(self, mean, stddev):
        out = self.normal(self.shape, mean, stddev)
        return out

class GammaNet(nn.Cell):
    def __init__(self, shape=None, seed=0):
        super(GammaNet, self).__init__()
        self.gamma = P.Gamma(seed=seed)
        self.shape = shape

    def construct(self, alpha, beta):
        out = self.gamma(self.shape, alpha, beta)
        return out

class PoissonNet(nn.Cell):
    def __init__(self, shape=None, seed=0):
        super(PoissonNet, self).__init__()
        self.poisson = P.Poisson(seed=seed)
        self.shape = shape

    def construct(self, mean):
        out = self.poisson(self.shape, mean)
        return out

class UniformIntNet(nn.Cell):
    def __init__(self, shape=None, seed=0):
        super(UniformIntNet, self).__init__()
        self.uniformint = P.UniformInt(seed=seed)
        self.shape = shape

    def construct(self, a, b):
        out = self.uniformint(self.shape, a, b)
        return out

class UniformRealNet(nn.Cell):
    def __init__(self, shape=None, seed=0):
        super(UniformRealNet, self).__init__()
        self.uniformreal = P.UniformReal(seed=seed)
        self.shape = shape

    def construct(self, a, b):
        out = self.uniformreal(self.shape, a, b)
        return out
@@ -620,6 +662,26 @@ test_case_math_ops = [
        (1, 1, 1)],
        'desc_inputs': [[64, 128, 1024]],
        'skip': ['backward']}),
    ('Normal', {
        'block': NormalNet((3, 2, 4), 0),
        'desc_inputs': [Tensor(1.0, mstype.float32), Tensor(1.0, mstype.float32)],
        'skip': ['backward']}),
    ('Gamma', {
        'block': GammaNet((3, 2, 4), 0),
        'desc_inputs': [Tensor(1.0, mstype.float32), Tensor(1.0, mstype.float32)],
        'skip': ['backward']}),
    ('Poisson', {
        'block': PoissonNet((3, 2, 4), 0),
        'desc_inputs': [Tensor(2.0, mstype.float32)],
        'skip': ['backward']}),
    ('UniformInt', {
        'block': UniformIntNet((3, 2, 4), 0),
        'desc_inputs': [Tensor(1, mstype.int32), Tensor(15, mstype.int32)],
        'skip': ['backward']}),
    ('UniformReal', {
        'block': UniformRealNet((3, 2, 4), 0),
        'desc_inputs': [Tensor(1.0, mstype.float32), Tensor(5.0, mstype.float32)],
        'skip': ['backward']}),
    ('RandomChoiceWithMask', {
        'block': P.RandomChoiceWithMask(256),
        'desc_inputs': [Tensor(np.random.rand(24000, 4).astype(np.bool_))],
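
Editor's note, not part of this merge request: in this table 'block' is the Cell under test, 'desc_inputs' lists the positional arguments the shared harness feeds to construct, and 'skip': ['backward'] leaves out the gradient check. For the 'Normal' entry added above, the harness therefore does roughly the following (a hedged sketch; the actual runner lives elsewhere in the test framework):

net = NormalNet((3, 2, 4), 0)                                         # 'block'
out = net(Tensor(1.0, mstype.float32), Tensor(1.0, mstype.float32))  # 'desc_inputs'
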
@@ -908,10 +970,6 @@ test_case_math_ops = [
        'desc_inputs': [Tensor([-1.0, 0.0, 1.5, 2.0, 5.0, 15], mstype.float16), Tensor([0.0, 5.0], mstype.float16)],
        'desc_bprop': [],
        'skip': ['backward']}),
    ('Normal', {
        'block': NormalNet((3, 2, 4), 0.0, 1.0, 0),
        'desc_inputs': [],
        'skip': ['backward']}),
]
test_case_nn_ops = [
