initial commit

fix testcases

fix ci
pull/9245/head
Peilin Wang 4 years ago
parent 689f102f86
commit e4444a1c12

@ -290,8 +290,8 @@ class SampledSoftmaxLoss(_Loss):
num_classes (int): The number of possible classes. num_classes (int): The number of possible classes.
num_true (int): The number of target classes per training example. num_true (int): The number of target classes per training example.
sampled_values (Tuple): Tuple of (`sampled_candidates`, `true_expected_count`, sampled_values (Tuple): Tuple of (`sampled_candidates`, `true_expected_count`,
`sampled_expected_count`) returned by a `*_candidate_sampler` function. `sampled_expected_count`) returned by a `*CandidateSampler` function.
Default to None, `log_uniform_candidate_sampler` is applied. Default to None, `UniformCandidateSampler` is applied.
remove_accidental_hits (bool): Whether to remove "accidental hits" remove_accidental_hits (bool): Whether to remove "accidental hits"
where a sampled class equals one of the target classes. Default is True. where a sampled class equals one of the target classes. Default is True.
seed (int): Random seed for candidate sampling. Default: 0 seed (int): Random seed for candidate sampling. Default: 0
@ -301,7 +301,7 @@ class SampledSoftmaxLoss(_Loss):
Inputs: Inputs:
- **weights** (Tensor) - Tensor of shape (C, dim). - **weights** (Tensor) - Tensor of shape (C, dim).
- **bias** (Tensor) - Tensor of shape (C). The class biases. - **bias** (Tensor) - Tensor of shape (C). The class biases.
- **labels** (Tensor) - Tensor of shape (N, num_true), type `int64`. The - **labels** (Tensor) - Tensor of shape (N, num_true), type `int64, int32`. The
target classes. target classes.
- **inputs** (Tensor) - Tensor of shape (N, dim). The forward activations of - **inputs** (Tensor) - Tensor of shape (N, dim). The forward activations of
the input network. the input network.
@ -414,7 +414,7 @@ class SampledSoftmaxLoss(_Loss):
activations of the input network. activations of the input network.
num_true (int): The number of target classes per training example. num_true (int): The number of target classes per training example.
sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`, sampled_values: a tuple of (`sampled_candidates`, `true_expected_count`,
`sampled_expected_count`) returned by a `UniformSampler` function. `sampled_expected_count`) returned by a `UniformCandidateSampler` function.
subtract_log_q: A `bool`. whether to subtract the log expected count of subtract_log_q: A `bool`. whether to subtract the log expected count of
the labels in the sample to get the logits of the true labels. the labels in the sample to get the logits of the true labels.
Default is True. Default is True.

@ -27,7 +27,7 @@ from .multitype_ops.add_impl import hyper_add
from .multitype_ops.ones_like_impl import ones_like from .multitype_ops.ones_like_impl import ones_like
from .multitype_ops.zeros_like_impl import zeros_like from .multitype_ops.zeros_like_impl import zeros_like
from .random_ops import normal, laplace, uniform, gamma, poisson, multinomial from .random_ops import normal, laplace, uniform, gamma, poisson, multinomial
from .math_ops import count_nonzero, TensorDot from .math_ops import count_nonzero, tensor_dot
from .array_ops import repeat_elements from .array_ops import repeat_elements
@ -52,5 +52,5 @@ __all__ = [
'clip_by_value', 'clip_by_value',
'clip_by_global_norm', 'clip_by_global_norm',
'count_nonzero', 'count_nonzero',
'TensorDot', 'tensor_dot',
'repeat_elements'] 'repeat_elements']

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""math Operations.""" """array Operations."""
from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils from mindspore.ops.composite.multitype_ops import _constexpr_utils as const_utils
from mindspore.common import dtype as mstype from mindspore.common import dtype as mstype
from mindspore._checkparam import Validator as validator from mindspore._checkparam import Validator as validator
@ -69,7 +69,7 @@ def repeat_elements(x, rep, axis=0):
Examples: Examples:
>>> x = Tensor(np.array([[0, 1, 2], [3, 4, 5]]), mindspore.int32) >>> x = Tensor(np.array([[0, 1, 2], [3, 4, 5]]), mindspore.int32)
>>> output = C.RepeatElements(x, rep = 2, axis = 0) >>> output = C.repeat_elements(x, rep = 2, axis = 0)
>>> print(output) >>> print(output)
[[0, 1, 2], [[0, 1, 2],
[0, 1, 2], [0, 1, 2],

@ -75,7 +75,7 @@ def count_nonzero(x, axis=(), keep_dims=False, dtype=mstype.int32):
return nonzero_num return nonzero_num
# TensorDot # tensor dot
@constexpr @constexpr
def _int_to_tuple_conv(axes): def _int_to_tuple_conv(axes):
""" """
@ -92,7 +92,7 @@ def _check_axes(axes):
""" """
Check for validity and type of axes passed to function. Check for validity and type of axes passed to function.
""" """
validator.check_value_type('axes', axes, [int, tuple, list], "TensorDot") validator.check_value_type('axes', axes, [int, tuple, list], "tensor dot")
if not isinstance(axes, int): if not isinstance(axes, int):
axes = list(axes) # to avoid immutability issues axes = list(axes) # to avoid immutability issues
if len(axes) != 2: if len(axes) != 2:
@ -156,7 +156,7 @@ def _calc_new_shape(shape, axes, position=0):
return new_shape, transpose_perm, free_dims return new_shape, transpose_perm, free_dims
def TensorDot(x1, x2, axes): def tensor_dot(x1, x2, axes):
""" """
Computation of Tensor contraction on arbitrary axes between tensors `a` and `b`. Computation of Tensor contraction on arbitrary axes between tensors `a` and `b`.
@ -171,8 +171,8 @@ def TensorDot(x1, x2, axes):
axes = 2 is the same as axes = ((0,1),(1,2)) where length of input shape is 3 for both `a` and `b` axes = 2 is the same as axes = ((0,1),(1,2)) where length of input shape is 3 for both `a` and `b`
Inputs: Inputs:
- **x1** (Tensor) - First tensor in TensorDot op with datatype float16 or float32 - **x1** (Tensor) - First tensor in tensor_dot with datatype float16 or float32
- **x2** (Tensor) - Second tensor in TensorDot op with datatype float16 or float32 - **x2** (Tensor) - Second tensor in tensor_dot with datatype float16 or float32
- **axes** (Union[int, tuple(int), tuple(tuple(int)), list(list(int))]) - Single value or - **axes** (Union[int, tuple(int), tuple(tuple(int)), list(list(int))]) - Single value or
tuple/list of length 2 with dimensions specified for `a` and `b` each. If single value `N` passed, tuple/list of length 2 with dimensions specified for `a` and `b` each. If single value `N` passed,
automatically picks up first N dims from `a` input shape and last N dims from `b` input shape. automatically picks up first N dims from `a` input shape and last N dims from `b` input shape.
@ -184,7 +184,7 @@ def TensorDot(x1, x2, axes):
Examples: Examples:
>>> input_x1 = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32) >>> input_x1 = Tensor(np.ones(shape=[1, 2, 3]), mindspore.float32)
>>> input_x2 = Tensor(np.ones(shape=[3, 1, 2]), mindspore.float32) >>> input_x2 = Tensor(np.ones(shape=[3, 1, 2]), mindspore.float32)
>>> output = C.TensorDot(input_x1, input_x2, ((0,1),(1,2))) >>> output = C.tensor_dot(input_x1, input_x2, ((0,1),(1,2)))
>>> print(output) >>> print(output)
[[2,2,2], [[2,2,2],
[2,2,2], [2,2,2],
@ -206,7 +206,7 @@ def TensorDot(x1, x2, axes):
x1_reshape_fwd, x1_transpose_fwd, x1_ret = _calc_new_shape(x1_shape, axes, 0) x1_reshape_fwd, x1_transpose_fwd, x1_ret = _calc_new_shape(x1_shape, axes, 0)
x2_reshape_fwd, x2_transpose_fwd, x2_ret = _calc_new_shape(x2_shape, axes, 1) x2_reshape_fwd, x2_transpose_fwd, x2_ret = _calc_new_shape(x2_shape, axes, 1)
output_shape = x1_ret + x2_ret # combine free axes from both inputs output_shape = x1_ret + x2_ret # combine free axes from both inputs
# run TensorDot op # run tensor_dot op
x1_transposed = transpose_op(x1, x1_transpose_fwd) x1_transposed = transpose_op(x1, x1_transpose_fwd)
x2_transposed = transpose_op(x2, x2_transpose_fwd) x2_transposed = transpose_op(x2, x2_transpose_fwd)
x1_reshaped = reshape_op(x1_transposed, x1_reshape_fwd) x1_reshaped = reshape_op(x1_transposed, x1_reshape_fwd)

@ -723,8 +723,10 @@ class Unique(Primitive):
- **x** (Tensor) - The input tensor. - **x** (Tensor) - The input tensor.
Outputs: Outputs:
Tuple, containing Tensor objects `(y, idx)`, `y` is a tensor has the same type as `x`, `idx` is a tensor Tuple, containing Tensor objects `(y, idx)., `y` is a tensor with the
containing indices of elements in the input coressponding to the output tensor. same type as `x`, and contains the unique elements in `x`, sorted in
ascending order. `idx` is a tensor containing indices of elements in
the input corresponding to the output tensor.
Supported Platforms: Supported Platforms:
``Ascend`` ``GPU`` ``CPU`` ``Ascend`` ``GPU`` ``CPU``
@ -734,6 +736,23 @@ class Unique(Primitive):
>>> output = ops.Unique()(x) >>> output = ops.Unique()(x)
>>> print(output) >>> print(output)
(Tensor(shape=[3], dtype=Int32, value= [1, 2, 5]), Tensor(shape=[4], dtype=Int32, value= [0, 1, 2, 1])) (Tensor(shape=[3], dtype=Int32, value= [1, 2, 5]), Tensor(shape=[4], dtype=Int32, value= [0, 1, 2, 1]))
>>>
>>> # note that for GPU, this operator must be wrapped inside a model, and executed in graph mode.
>>> class UniqueNet(nn.Cell):
>>> def __init__(self):
>>> super(UniqueNet, self).__init__()
>>> self.unique_op = P.Unique()
>>>
>>> def construct(self, x):
>>> output, indices = self.unique_op(x)
>>> return output, indices
>>>
>>> x = Tensor(np.array([1, 2, 5, 2]), mindspore.int32)
>>> context.set_context(mode=context.GRAPH_MODE, device_target="GPU")
>>> net = UniqueNet()
>>> output = net(x)
>>> print(output)
(Tensor(shape=[3], dtype=Int32, value= [1, 2, 5]), Tensor(shape=[4], dtype=Int32, value= [0, 1, 2, 1]))
""" """
@prim_attr_register @prim_attr_register

@ -29,7 +29,7 @@ class NetTensorDot(nn.Cell):
self.axes = axes self.axes = axes
def construct(self, x, y): def construct(self, x, y):
return C.TensorDot(x, y, self.axes) return C.tensor_dot(x, y, self.axes)
class GradNetwork(nn.Cell): class GradNetwork(nn.Cell):

Loading…
Cancel
Save