|
|
|
@ -633,7 +633,7 @@ class TruncatedNormal(PrimitiveWithInfer):
|
|
|
|
|
dtype (:class:`mindspore.dtype`): Data type. Default: mindspore.float32.
|
|
|
|
|
|
|
|
|
|
Inputs:
|
|
|
|
|
- **shape** (Tensor) - Shape of output tensor. The shape is a 1-D tensor, and type is int.
|
|
|
|
|
- **shape** (tuple[int]) - Shape of output tensor, is a tuple of positive int.
|
|
|
|
|
|
|
|
|
|
Outputs:
|
|
|
|
|
Tensor, type of output tensor is same as attribute `dtype`.
|
|
|
|
@ -651,16 +651,10 @@ class TruncatedNormal(PrimitiveWithInfer):
|
|
|
|
|
validator.check_typename('dtype', dtype, mstype.number_type)
|
|
|
|
|
|
|
|
|
|
def __infer__(self, shape):
|
|
|
|
|
shape_t = shape['value']
|
|
|
|
|
validator.check_subclass("shape", shape['dtype'], mstype.tensor)
|
|
|
|
|
shape_n = shape_t.asnumpy()
|
|
|
|
|
if shape_n.ndim != 1:
|
|
|
|
|
raise ValueError('The rank of input shape must be 1.')
|
|
|
|
|
if shape_n.dtype not in (np.int32, np.int64):
|
|
|
|
|
raise TypeError('The type of input shape must be int32 or int64.')
|
|
|
|
|
for i, item in enumerate(shape_n):
|
|
|
|
|
validator.check_integer(f"shape[{i}]", item.item(), 0, Rel.GT)
|
|
|
|
|
out = {'shape': tuple(shape_n),
|
|
|
|
|
shape_value = shape['value']
|
|
|
|
|
for i, value in enumerate(shape_value):
|
|
|
|
|
validator.check_integer(f'{i}th value of shape', value, 0, Rel.GT)
|
|
|
|
|
out = {'shape': shape_value,
|
|
|
|
|
'dtype': mstype.tensor_type(self.dtype),
|
|
|
|
|
'value': None}
|
|
|
|
|
return out
|
|
|
|
@ -1648,20 +1642,19 @@ class StridedSlice(PrimitiveWithInfer):
|
|
|
|
|
validator.check_type('shrink_axis_mask', shrink_axis_mask, [int])
|
|
|
|
|
|
|
|
|
|
def __infer__(self, x, begin, end, strides):
|
|
|
|
|
begin_shape, end_shape, strides_shape = begin['shape'], end['shape'], strides['shape']
|
|
|
|
|
if begin_shape != strides_shape or end_shape != strides_shape:
|
|
|
|
|
raise ValueError("The shape of begin, end and strides in 'StridedSlice' must be equal.")
|
|
|
|
|
|
|
|
|
|
validator.check_const_input("begin", begin['value'])
|
|
|
|
|
validator.check_const_input("end", end['value'])
|
|
|
|
|
validator.check_const_input("strides", strides['value'])
|
|
|
|
|
x_shape = x['shape']
|
|
|
|
|
x_shp_len = len(x_shape)
|
|
|
|
|
begin_v, end_v, strides_v = begin['value'], end['value'], strides['value']
|
|
|
|
|
validator.check_const_input("begin", begin_v)
|
|
|
|
|
validator.check_const_input("end", end_v)
|
|
|
|
|
validator.check_const_input("strides", strides_v)
|
|
|
|
|
validator.check_type("begin", begin['value'], [tuple])
|
|
|
|
|
validator.check_type("end", end['value'], [tuple])
|
|
|
|
|
validator.check_type("strides", strides['value'], [tuple])
|
|
|
|
|
if len(begin_v) != x_shp_len or len(end_v) != x_shp_len or len(strides_v) != x_shp_len:
|
|
|
|
|
raise ValueError(f"The length of begin index{begin_v}, end index{end_v} and strides{strides_v} "
|
|
|
|
|
f"must be equal to the dims({x_shp_len}) of input.")
|
|
|
|
|
|
|
|
|
|
x_shape = x['shape']
|
|
|
|
|
x_shp_len = len(x_shape)
|
|
|
|
|
begin_v, end_v, strides_v = begin['value'], end['value'], strides['value']
|
|
|
|
|
ret_shape = []
|
|
|
|
|
append_dimensions = []
|
|
|
|
|
shrink_pos = bin(self.shrink_axis_mask)[::-1]
|
|
|
|
|