pull/7981/head
lihongkang 4 years ago
parent 88aaec279c
commit 3704e05370

@ -75,11 +75,12 @@ class Dropout(Cell):
Examples: Examples:
>>> x = Tensor(np.ones([2, 2, 3]), mindspore.float32) >>> x = Tensor(np.ones([2, 2, 3]), mindspore.float32)
>>> net = nn.Dropout(keep_prob=0.8) >>> net = nn.Dropout(keep_prob=0.8)
>>> net.set_train()
>>> net(x) >>> net(x)
[[[1.0, 1.0, 1.0], [[[0., 1.25, 0.],
[1.0, 1.0, 1.0]], [1.25, 1.25, 1.25]],
[[1.0, 1.0, 1.0], [[1.25, 1.25, 1.25],
[1.0, 1.0, 1.0]]] [1.25, 1.25, 1.25]]]
""" """
def __init__(self, keep_prob=0.5, dtype=mstype.float32): def __init__(self, keep_prob=0.5, dtype=mstype.float32):
@ -287,7 +288,8 @@ class ClipByNorm(Cell):
>>> net = nn.ClipByNorm() >>> net = nn.ClipByNorm()
>>> input = Tensor(np.random.randint(0, 10, [4, 16]), mindspore.float32) >>> input = Tensor(np.random.randint(0, 10, [4, 16]), mindspore.float32)
>>> clip_norm = Tensor(np.array([100]).astype(np.float32)) >>> clip_norm = Tensor(np.array([100]).astype(np.float32))
>>> net(input, clip_norm) >>> net(input, clip_norm).shape
(4, 16)
""" """

@ -447,6 +447,8 @@ class CentralCrop(Cell):
>>> net = nn.CentralCrop(central_fraction=0.5) >>> net = nn.CentralCrop(central_fraction=0.5)
>>> image = Tensor(np.random.random((4, 3, 4, 4)), mindspore.float32) >>> image = Tensor(np.random.random((4, 3, 4, 4)), mindspore.float32)
>>> output = net(image) >>> output = net(image)
>>> output.shape
(4, 3, 2, 2)
""" """
def __init__(self, central_fraction): def __init__(self, central_fraction):

@ -1941,7 +1941,7 @@ class Slice(PrimitiveWithInfer):
""" """
Slices a tensor in the specified shape. Slices a tensor in the specified shape.
Args: Inputs:
x (Tensor): The target tensor. x (Tensor): The target tensor.
begin (tuple): The beginning of the slice. Only constant value is allowed. begin (tuple): The beginning of the slice. Only constant value is allowed.
size (tuple): The size of the slice. Only constant value is allowed. size (tuple): The size of the slice. Only constant value is allowed.
@ -2262,7 +2262,7 @@ class StridedSlice(PrimitiveWithInfer):
validator.check_value_type("strides", strides_v, [tuple], self.name) validator.check_value_type("strides", strides_v, [tuple], self.name)
if tuple(filter(lambda x: not isinstance(x, int), begin_v + end_v + strides_v)): if tuple(filter(lambda x: not isinstance(x, int), begin_v + end_v + strides_v)):
raise ValueError(f"For {self.name}, both the begins, ends, and strides must be a tuple of int, " raise TypeError(f"For {self.name}, both the begins, ends, and strides must be a tuple of int, "
f"but got begins: {begin_v}, ends: {end_v}, strides: {strides_v}.") f"but got begins: {begin_v}, ends: {end_v}, strides: {strides_v}.")
if tuple(filter(lambda x: x == 0, strides_v)): if tuple(filter(lambda x: x == 0, strides_v)):

@ -724,11 +724,29 @@ class BatchMatMul(MatMul):
>>> input_y = Tensor(np.ones(shape=[2, 4, 3, 4]), mindspore.float32) >>> input_y = Tensor(np.ones(shape=[2, 4, 3, 4]), mindspore.float32)
>>> batmatmul = P.BatchMatMul() >>> batmatmul = P.BatchMatMul()
>>> output = batmatmul(input_x, input_y) >>> output = batmatmul(input_x, input_y)
[[[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]]
[[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]]]
>>> >>>
>>> input_x = Tensor(np.ones(shape=[2, 4, 3, 1]), mindspore.float32) >>> input_x = Tensor(np.ones(shape=[2, 4, 3, 1]), mindspore.float32)
>>> input_y = Tensor(np.ones(shape=[2, 4, 3, 4]), mindspore.float32) >>> input_y = Tensor(np.ones(shape=[2, 4, 3, 4]), mindspore.float32)
>>> batmatmul = P.BatchMatMul(transpose_a=True) >>> batmatmul = P.BatchMatMul(transpose_a=True)
>>> output = batmatmul(input_x, input_y) >>> output = batmatmul(input_x, input_y)
[[[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]]
[[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]
[[3. 3. 3. 3.]]]]
""" """
@prim_attr_register @prim_attr_register

Loading…
Cancel
Save