add linear and xwplusb

pull/8227/head
jiangzhenguang 4 years ago
parent 244b7034e8
commit 084f428f9d

@ -124,13 +124,13 @@ def clip_by_global_norm(x, clip_norm=1.0, use_norm=None):
>>> x2 = np.array([[1., 4.],[3., 1.]]).astype(np.float32)
>>> input_x = (Tensor(x1), Tensor(x2))
>>> out = clip_by_global_norm(input_x, 1.0)
>>> print(out)
([[ 2.98142403e-01, 4.47213590e-01],
[ 1.49071202e-01, 2.98142403e-01]],
[[ 1.49071202e-01, 5.96284807e-01],
[ 4.47213590e-01, 1.49071202e-01]])
"""
clip_norm = _check_value(clip_norm)
out = _ClipByGlobalNorm(clip_norm, use_norm)(x)
return out
clip_val = _ClipByGlobalNorm(clip_norm, use_norm)(x)
return clip_val

@ -41,22 +41,22 @@ def count_nonzero(x, axis=(), keep_dims=False, dtype=mstype.int32):
Count number of nonzero elements across axis of input tensor
Args:
- **x** (Tensor[Number]) - Input data is used to count non-zero numbers.
- **axis** (Union[int, tuple(int), list(int)]) - The dimensions to reduce. Only constant value is allowed.
Default: (), reduce all dimensions.
- **keep_dims** (bool) - If true, keep these reduced dimensions and the length is 1.
If false, don't keep these dimensions. Default: False.
- **dtype** (Union[Number, mstype.bool_]) - The data type of the output tensor. Only constant value is allowed.
Default: mstype.int32
x (Union(tuple[Tensor], list[Tensor])): Input data is used to count non-zero numbers.
axis (Union[int, tuple(int), list(int)]): The dimensions to reduce. Only constant value is allowed.
Default: (), reduce all dimensions.
keep_dims (bool): If true, keep these reduced dimensions and the length is 1.
If false, don't keep these dimensions. Default: False.
dtype (Union[Number, mstype.bool_]): The data type of the output tensor. Only constant value is allowed.
Default: mstype.int32
Returns:
Tensor, number of nonzero element. The data type is dtype.
Examples:
>>> input_tensor = Tensor(np.array([[0, 1, 0], [1, 1, 0]]).astype(np.float32))
>>> input_x = Tensor(np.array([[0, 1, 0], [1, 1, 0]]).astype(np.float32))
>>> nonzero_num = count_nonzero(x=input_x, axis=[0, 1], keep_dims=True, dtype=mstype.int32)
nonzero_num: [[3]]
>>> print(nonzero_num)
[[3]]
"""
const_utils.check_valid_type(F.dtype(x), mstype.number_type, 'input x')

Loading…
Cancel
Save