|
|
|
@ -145,10 +145,10 @@ class SameTypeShape(PrimitiveWithInfer):
|
|
|
|
|
|
|
|
|
|
def __call__(self, x, y):
|
|
|
|
|
"""run in PyNative mode"""
|
|
|
|
|
validator.check_value_type("x", x, Tensor, self.name)
|
|
|
|
|
validator.check_value_type("y", y, Tensor, self.name)
|
|
|
|
|
validator.check('x dtype', x.dtype(), 'y dtype', y.dtype(), Rel.EQ, self.name, TypeError)
|
|
|
|
|
validator.check('x shape', x.shape(), 'y shape', y.shape(), Rel.EQ, self.name)
|
|
|
|
|
validator.check_value_type('x', x, Tensor, self.name)
|
|
|
|
|
validator.check_value_type('y', y, Tensor, self.name)
|
|
|
|
|
validator.check('x dtype', x.dtype, 'y dtype', y.dtype, Rel.EQ, self.name, TypeError)
|
|
|
|
|
validator.check('x shape', x.shape, 'y shape', y.shape, Rel.EQ, self.name)
|
|
|
|
|
return x
|
|
|
|
|
|
|
|
|
|
def __infer__(self, x, y):
|
|
|
|
@ -187,7 +187,7 @@ class Cast(PrimitiveWithInfer):
|
|
|
|
|
|
|
|
|
|
def check_elim(self, x, dtype):
|
|
|
|
|
if isinstance(x, Tensor):
|
|
|
|
|
if x.dtype() == dtype:
|
|
|
|
|
if x.dtype == dtype:
|
|
|
|
|
return (True, x)
|
|
|
|
|
return (False, None)
|
|
|
|
|
raise ValueError("Expecting (Tensor, dtype), got : {}".format(inputs))
|
|
|
|
@ -498,7 +498,7 @@ class GatherV2(PrimitiveWithInfer):
|
|
|
|
|
The original Tensor.
|
|
|
|
|
- **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
|
|
|
|
|
Specifies the indices of elements of the original Tensor. Must be in the range
|
|
|
|
|
`[0, input_param.shape()[axis])`.
|
|
|
|
|
`[0, input_param.shape[axis])`.
|
|
|
|
|
- **axis** (int) - Specifies the dimension index to gather indices.
|
|
|
|
|
|
|
|
|
|
Outputs:
|
|
|
|
@ -542,7 +542,7 @@ class SparseGatherV2(GatherV2):
|
|
|
|
|
The original Tensor.
|
|
|
|
|
- **input_indices** (Tensor) - The shape of tensor is :math:`(y_1, y_2, ..., y_S)`.
|
|
|
|
|
Specifies the indices of elements of the original Tensor. Must be in the range
|
|
|
|
|
`[0, input_param.shape()[axis])`.
|
|
|
|
|
`[0, input_param.shape[axis])`.
|
|
|
|
|
- **axis** (int) - Specifies the dimension index to gather indices.
|
|
|
|
|
|
|
|
|
|
Outputs:
|
|
|
|
@ -700,7 +700,7 @@ class Split(PrimitiveWithInfer):
|
|
|
|
|
output_num (int): The number of output tensors. Default: 1.
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
ValueError: If axis is out of the range [-len(input_x.shape()), len(input_x.shape())),
|
|
|
|
|
ValueError: If axis is out of the range [-len(input_x.shape), len(input_x.shape)),
|
|
|
|
|
or if the output_num is less than or equal to 0, or if the
|
|
|
|
|
dimension which to split cannot be evenly divided by output_num.
|
|
|
|
|
|
|
|
|
@ -1644,7 +1644,7 @@ class Unpack(PrimitiveWithInfer):
|
|
|
|
|
A tuple of Tensors, the shape of each objects is same.
|
|
|
|
|
|
|
|
|
|
Raises:
|
|
|
|
|
ValueError: If axis is out of the range [-len(input_x.shape()), len(input_x.shape())).
|
|
|
|
|
ValueError: If axis is out of the range [-len(input_x.shape), len(input_x.shape)).
|
|
|
|
|
|
|
|
|
|
Examples:
|
|
|
|
|
>>> unpack = P.Unpack()
|
|
|
|
@ -1850,7 +1850,7 @@ class StridedSlice(PrimitiveWithInfer):
|
|
|
|
|
>>> [[5, 5, 5], [6, 6, 6]]], mindspore.float32)
|
|
|
|
|
>>> slice = P.StridedSlice()
|
|
|
|
|
>>> output = slice(input_x, (1, 0, 0), (2, 1, 3), (1, 1, 1))
|
|
|
|
|
>>> output.shape()
|
|
|
|
|
>>> output.shape
|
|
|
|
|
(1, 1, 3)
|
|
|
|
|
>>> output
|
|
|
|
|
[[[3, 3, 3]]]
|
|
|
|
@ -1974,7 +1974,7 @@ class Diag(PrimitiveWithInfer):
|
|
|
|
|
if x is None:
|
|
|
|
|
return None
|
|
|
|
|
# do constant-folding only when x rank is 1
|
|
|
|
|
if len(x.shape()) != 1:
|
|
|
|
|
if len(x.shape) != 1:
|
|
|
|
|
return None
|
|
|
|
|
ret = np.diag(x.asnumpy())
|
|
|
|
|
return Tensor(ret)
|
|
|
|
@ -2026,7 +2026,7 @@ class DiagPart(PrimitiveWithInfer):
|
|
|
|
|
if x is None:
|
|
|
|
|
return None
|
|
|
|
|
# do constant-folding only when x rank is 2
|
|
|
|
|
if len(x.shape()) != 2:
|
|
|
|
|
if len(x.shape) != 2:
|
|
|
|
|
return None
|
|
|
|
|
ret = np.diag(x.asnumpy())
|
|
|
|
|
return Tensor(ret)
|
|
|
|
|