hidden range

pull/2128/head
jiangjinsheng 5 years ago
parent 14b997448a
commit 032e921298

@ -45,8 +45,9 @@ class Embedding(Cell):
Inputs: Inputs:
- **input** (Tensor) - Tensor of shape :math:`(\text{batch_size}, \text{input_length})`. The element of - **input** (Tensor) - Tensor of shape :math:`(\text{batch_size}, \text{input_length})`. The element of
the Tensor should be integer and not larger than vocab_size. else the corresponding embedding vector is zero the Tensor should be integer and not larger than vocab_size. else the corresponding embedding vector is zero
if larger than vocab_size. if larger than vocab_size.
Outputs: Outputs:
Tensor of shape :math:`(\text{batch_size}, \text{input_length}, \text{embedding_size})`. Tensor of shape :math:`(\text{batch_size}, \text{input_length}, \text{embedding_size})`.

@ -17,6 +17,7 @@
from .. import operations as P from .. import operations as P
from ..operations import _grad_ops as G from ..operations import _grad_ops as G
from ..operations import _inner_ops as inner
from ..composite.multitype_ops.zeros_like_impl import zeros_like from ..composite.multitype_ops.zeros_like_impl import zeros_like
from ..functional import broadcast_gradient_args from ..functional import broadcast_gradient_args
from .. import functional as F from .. import functional as F
@ -341,7 +342,7 @@ def get_bprop_sparse_gather_v2(self):
return bprop return bprop
@bprop_getters.register(P.Range) @bprop_getters.register(inner.Range)
def get_bprop_range(self): def get_bprop_range(self):
"""Generate bprop for Range""" """Generate bprop for Range"""

@ -23,7 +23,7 @@ from .array_ops import (Argmax, Argmin, Cast, Concat, Pack, Unpack,
Diag, DiagPart, DType, ExpandDims, Eye, Diag, DiagPart, DType, ExpandDims, Eye,
Fill, GatherNd, GatherV2, SparseGatherV2, InvertPermutation, Fill, GatherNd, GatherV2, SparseGatherV2, InvertPermutation,
IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike, IsInstance, IsSubClass, ArgMaxWithValue, OnesLike, ZerosLike,
Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue, Range, Rank, Reshape, ResizeNearestNeighbor, ArgMinWithValue,
SameTypeShape, ScatterAdd, ScatterMax, ScatterUpdate, SameTypeShape, ScatterAdd, ScatterMax, ScatterUpdate,
ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select, ScalarToArray, ScalarToTensor, ScatterNd, ScatterNdUpdate, Select,
Shape, Size, Slice, Split, EmbeddingLookup, Shape, Size, Slice, Split, EmbeddingLookup,
@ -75,7 +75,7 @@ from .nn_ops import (LSTM, SGD, Adam, SparseApplyAdam, SparseApplyLazyAdam, Appl
ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2, ApplyAdaMax, ApplyAdadelta, ApplyAdagrad, ApplyAdagradV2,
ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK) ApplyRMSProp, ApplyCenteredRMSProp, BasicLSTMCell, InTopK)
from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode, from .other_ops import (Assign, IOU, BoundingBoxDecode, BoundingBoxEncode,
CheckValid, MakeRefKey, Partial, Depend, CheckBprop, ConfusionMatrix) CheckValid, MakeRefKey, Partial, Depend, CheckBprop)
from . import _quant_ops from . import _quant_ops
from ._quant_ops import * from ._quant_ops import *
from .thor_ops import * from .thor_ops import *
@ -303,13 +303,12 @@ __all__ = [
"Atan", "Atan",
"Atanh", "Atanh",
"BasicLSTMCell", "BasicLSTMCell",
"ConfusionMatrix",
"BroadcastTo", "BroadcastTo",
"Range",
"DataFormatDimMap", "DataFormatDimMap",
"ApproximateEqual", "ApproximateEqual",
"InplaceUpdate", "InplaceUpdate",
"InTopK", "InTopK",
"DataFormatDimMap"
] ]
__all__.extend(_quant_ops.__all__) __all__.extend(_quant_ops.__all__)

@ -15,9 +15,10 @@
"""Inner operators.""" """Inner operators."""
from ..._checkparam import Rel
from ..._checkparam import Validator as validator from ..._checkparam import Validator as validator
from ...common import dtype as mstype from ...common import dtype as mstype
from ..primitive import PrimitiveWithInfer, prim_attr_register from ..primitive import PrimitiveWithInfer, prim_attr_register
class ExtractImagePatches(PrimitiveWithInfer): class ExtractImagePatches(PrimitiveWithInfer):
@ -96,3 +97,61 @@ class ExtractImagePatches(PrimitiveWithInfer):
"""infer dtype""" """infer dtype"""
validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.name) validator.check_tensor_type_same({"input_x": input_x}, mstype.number_type, self.name)
return input_x return input_x
class Range(PrimitiveWithInfer):
r"""
Creates a sequence of numbers.
Set `input_x` as :math:`x_i` for each element, `output` as follows:
.. math::
\text{output}(x_i) = x_i * \text{delta} + \text{start}
Args:
start (float): If `limit` is `None`, the value acts as limit in the range and first entry
defaults to `0`. Otherwise, it acts as first entry in the range.
limit (float): Acts as upper limit of sequence. If `None`, defaults to the value of `start`
while set the first entry of the range to `0`. It can not be equal to `start`.
delta (float): Increment of the range. It can not be equal to zero. Default: 1.0.
Inputs:
- **input_x** (Tensor) - The assistant data. A `1-D` tensor of type float32 or int32.
Outputs:
Tensor, has the same shape and dtype as `input_x`.
Examples:
>>> range = P.Range(1.0, 8.0, 2.0)
>>> x = Tensor(np.array([1, 2, 3, 2]), mindspore.int32)
>>> range(x)
[3, 5, 7, 5]
"""
@prim_attr_register
def __init__(self, start, limit=None, delta=1.0):
self.init_prim_io_names(inputs=['x'], outputs=['y'])
self.delta = validator.check_value_type("delta", delta, [float], self.name)
validator.check_value_type("start", start, [float], self.name)
if limit is None:
self.start = 0.0
self.limit = start
self.add_prim_attr("start", self.start)
self.add_prim_attr("limit", self.limit)
else:
validator.check_value_type("limit", limit, [float], self.name)
validator.check('start', self.start, 'limit', self.limit, Rel.NE, self.name)
if self.delta == 0.0:
raise ValueError("The input of `delta` can not be equal to zero.")
if self.delta > 0.0 and self.start > self.limit:
raise ValueError(f"Limit should be greater than start when delta:{self.delta} is more than zero, "
f"but got start:{self.start}, limit:{self.limit}")
if self.delta < 0.0 and self.start < self.limit:
raise ValueError(f"Start should be greater than limit when delta:{self.delta} is less than zero, "
f"but got start:{self.start}, limit:{self.limit}")
def infer_shape(self, x_shape):
return x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x_dtype': x_dtype}, [mstype.float32, mstype.int32], self.name)
return x_dtype

@ -556,64 +556,6 @@ class SparseGatherV2(GatherV2):
""" """
class Range(PrimitiveWithInfer):
r"""
Creates a sequence of numbers.
Set `input_x` as :math:`x_i` for each element, `output` as follows:
.. math::
\text{output}(x_i) = x_i * \text{delta} + \text{start}
Args:
start (float): If `limit` is `None`, the value acts as limit in the range and first entry
defaults to `0`. Otherwise, it acts as first entry in the range.
limit (float): Acts as upper limit of sequence. If `None`, defaults to the value of `start`
while set the first entry of the range to `0`. It can not be equal to `start`.
delta (float): Increment of the range. It can not be equal to zero. Default: 1.0.
Inputs:
- **input_x** (Tensor) - The assistant data. A `1-D` tensor of type float32 or int32.
Outputs:
Tensor, has the same shape and dtype as `input_x`.
Examples:
>>> range = P.Range(1.0, 8.0, 2.0)
>>> x = Tensor(np.array([1, 2, 3, 2]), mindspore.int32)
>>> range(x)
[3, 5, 7, 5]
"""
@prim_attr_register
def __init__(self, start, limit=None, delta=1.0):
self.init_prim_io_names(inputs=['x'], outputs=['y'])
self.delta = validator.check_value_type("delta", delta, [float], self.name)
validator.check_value_type("start", start, [float], self.name)
if limit is None:
self.start = 0.0
self.limit = start
self.add_prim_attr("start", self.start)
self.add_prim_attr("limit", self.limit)
else:
validator.check_value_type("limit", limit, [float], self.name)
validator.check('start', self.start, 'limit', self.limit, Rel.NE, self.name)
if self.delta == 0.0:
raise ValueError("The input of `delta` can not be equal to zero.")
if self.delta > 0.0 and self.start > self.limit:
raise ValueError(f"Limit should be greater than start when delta:{self.delta} is more than zero, "
f"but got start:{self.start}, limit:{self.limit}")
if self.delta < 0.0 and self.start < self.limit:
raise ValueError(f"Start should be greater than limit when delta:{self.delta} is less than zero, "
f"but got start:{self.start}, limit:{self.limit}")
def infer_shape(self, x_shape):
return x_shape
def infer_dtype(self, x_dtype):
validator.check_tensor_type_same({'x_dtype': x_dtype}, [mstype.float32, mstype.int32], self.name)
return x_dtype
class EmbeddingLookup(PrimitiveWithInfer): class EmbeddingLookup(PrimitiveWithInfer):
""" """
Returns a slice of input tensor based on the specified indices. Returns a slice of input tensor based on the specified indices.

@ -25,6 +25,7 @@ from mindspore import Tensor
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore.nn import Cell from mindspore.nn import Cell
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops.operations import _inner_ops as inner
from mindspore.ops import prim_attr_register from mindspore.ops import prim_attr_register
from mindspore.ops.primitive import PrimitiveWithInfer from mindspore.ops.primitive import PrimitiveWithInfer
import mindspore.context as context import mindspore.context as context
@ -286,19 +287,10 @@ class SpaceToBatchNDNet(Cell):
return self.space_to_batch_nd(x) return self.space_to_batch_nd(x)
class ConfusionMatrixNet(Cell):
def __init__(self):
super(ConfusionMatrixNet, self).__init__()
self.confusion_matrix = P.ConfusionMatrix(4, "int32")
def construct(self, x, y):
return self.confusion_matrix(x, y)
class RangeNet(Cell): class RangeNet(Cell):
def __init__(self): def __init__(self):
super(RangeNet, self).__init__() super(RangeNet, self).__init__()
self.range_ops = P.Range(1.0, 8.0, 2.0) self.range_ops = inner.Range(1.0, 8.0, 2.0)
def construct(self, x): def construct(self, x):
return self.range_ops(x) return self.range_ops(x)
@ -344,9 +336,6 @@ test_case_array_ops = [
('BatchToSpaceNDNet', { ('BatchToSpaceNDNet', {
'block': BatchToSpaceNDNet(), 'block': BatchToSpaceNDNet(),
'desc_inputs': [Tensor(np.random.rand(4, 1, 1, 1).astype(np.float16))]}), 'desc_inputs': [Tensor(np.random.rand(4, 1, 1, 1).astype(np.float16))]}),
('ConfusionMatrixNet', {
'block': ConfusionMatrixNet(),
'desc_inputs': [Tensor([0, 1, 1, 3], ms.int32), Tensor([0, 1, 1, 3], ms.int32)]}),
('RangeNet', { ('RangeNet', {
'block': RangeNet(), 'block': RangeNet(),
'desc_inputs': [Tensor(np.array([1, 2, 3, 2]), ms.int32)]}), 'desc_inputs': [Tensor(np.array([1, 2, 3, 2]), ms.int32)]}),

@ -25,6 +25,7 @@ from mindspore.common import dtype as mstype
from mindspore.ops import functional as F from mindspore.ops import functional as F
from mindspore.ops import operations as P from mindspore.ops import operations as P
from mindspore.ops.operations import _grad_ops as G from mindspore.ops.operations import _grad_ops as G
from mindspore.ops.operations import _inner_ops as inner
from ..ut_filter import non_graph_engine from ..ut_filter import non_graph_engine
from ....mindspore_test_framework.mindspore_test import mindspore_test from ....mindspore_test_framework.mindspore_test import mindspore_test
from ....mindspore_test_framework.pipeline.forward.compile_forward \ from ....mindspore_test_framework.pipeline.forward.compile_forward \
@ -1051,7 +1052,7 @@ test_case_nn_ops = [
'desc_inputs': [[3, 1, 2], Tensor(np.array([0, 1]).astype(np.int32))], 'desc_inputs': [[3, 1, 2], Tensor(np.array([0, 1]).astype(np.int32))],
'desc_bprop': [[2, 1, 2]]}), 'desc_bprop': [[2, 1, 2]]}),
('Range', { ('Range', {
'block': P.Range(1.0, 5.0), 'block': inner.Range(1.0, 5.0),
'desc_inputs': [Tensor(np.ones([10]).astype(np.float32))], 'desc_inputs': [Tensor(np.ones([10]).astype(np.float32))],
'desc_bprop': [[10]]}), 'desc_bprop': [[10]]}),
('UnsortedSegmentSum', { ('UnsortedSegmentSum', {
@ -1454,7 +1455,7 @@ test_case_array_ops = [
'desc_inputs': [(Tensor(np.array([1], np.float32)), 'desc_inputs': [(Tensor(np.array([1], np.float32)),
Tensor(np.array([1], np.float32)), Tensor(np.array([1], np.float32)),
Tensor(np.array([1], np.float32)))], Tensor(np.array([1], np.float32)))],
'desc_bprop': [[3,]]}), 'desc_bprop': [[3, ]]}),
('Pack_0', { ('Pack_0', {
'block': NetForPackInput(P.Pack()), 'block': NetForPackInput(P.Pack()),
'desc_inputs': [[2, 2], [2, 2], [2, 2]], 'desc_inputs': [[2, 2], [2, 2], [2, 2]],
@ -1527,7 +1528,7 @@ test_case_array_ops = [
Tensor(np.array([0, 1, 1]).astype(np.int32))], Tensor(np.array([0, 1, 1]).astype(np.int32))],
'desc_bprop': [Tensor(np.array([[1, 2, 3], [4, 2, 1]]).astype(np.float32))]}), 'desc_bprop': [Tensor(np.array([[1, 2, 3], [4, 2, 1]]).astype(np.float32))]}),
('BroadcastTo', { ('BroadcastTo', {
'block': P.BroadcastTo((2,3)), 'block': P.BroadcastTo((2, 3)),
'desc_inputs': [Tensor(np.array([1, 2, 3]).astype(np.float32))], 'desc_inputs': [Tensor(np.array([1, 2, 3]).astype(np.float32))],
'desc_bprop': [Tensor(np.array([[1, 2, 3], [1, 2, 3]]).astype(np.float32))]}), 'desc_bprop': [Tensor(np.array([[1, 2, 3], [1, 2, 3]]).astype(np.float32))]}),
('InTopK', { ('InTopK', {

Loading…
Cancel
Save